Commit ab122dac authored by yuguo
Browse files

[DCU] compile pass

parent 4c6a5a27
...@@ -126,6 +126,84 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons ...@@ -126,6 +126,84 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
} }
} }
#ifdef __HIP_PLATFORM_AMD__
// DCU/HIP specialization of the column-wise scaling-factor swizzle for plain
// 32-bit (non-vectorized) loads; kept textually parallel to the templated
// LType kernel above.  Reads M-major scale-inverse tiles from `input` and
// writes the swizzled tile layout to `output`.
// Launch layout: blockIdx.x walks tiles along K, blockIdx.y along M; dynamic
// shared memory sized for one thread block's worth of tiles must be passed at
// launch (the caller sizes it in int8 units; see the launcher).
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
                                               const int K) {
  // Always 1 for int/int; retained so the body mirrors the vectorized variant
  // where this is sizeof(LType) / sizeof(int).
  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
  // Tile footprint in 32-bit words (scale factors are packed 4 per int32,
  // matching the M/4 and K/4 conversions below).
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
  // input is in M-major
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
  const int M_i32 = M / 4;
  const int K_i32 = K;
  // Tiles this block owns in each dimension; the trailing block in either
  // grid dimension gets the remainder.
  int m_tiles_in_tb = N_TILE_PER_TD;
  int k_tiles_in_tb = TB_DIM;
  if (blockIdx.x == gridDim.x - 1) {
    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
  }
  if (blockIdx.y == gridDim.y - 1) {
    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
  }
  // Base of this block's input region (M-major addressing).
  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
  // Per-M-tile output base pointers (only the first m_tiles_in_tb entries are
  // initialized and later used).
  int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
  }
  extern __shared__ int slm[];
  // load, global -> regs
  int regs_vec[N_SF_PER_TD_PER_TILE];
  // Guard: only threads that map onto live (tile, row) positions participate
  // in the load/shuffle/shared-store phase.
  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
      threadIdx.y < k_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = *reinterpret_cast<const int*>(
          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
    }
    // local shuffle (in-register byte re-arrangement; defined elsewhere)
    regs_shuffle_with_bit_shifts(regs_vec);
    // store, regs -> shared
    int tM = threadIdx.x * N_SF_PER_TD;
    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD; i++) {
      /* TODO rotate_i */
      // NOTE(review): NEW_SF_TILE_DIM_M_I32 / NEW_SF_TILE_DIM_K_I32 are
      // defined outside this view — the swizzled-tile shape is assumed to
      // match the launcher's shared-memory sizing; confirm against those
      // definitions.
      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
          reinterpret_cast<int*>(regs_vec)[i];
    }
  }
  // All threads reach this barrier (the guarded region above contains no
  // barrier), so the shared tile is fully populated before the copy-out.
  __syncthreads();
  // store, shared -> global
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
  for (int i = 0; i < m_tiles_in_tb; i++) {
    // 16-byte vectorized copy-out; assumes tile sizes keep both pointers
    // int4-aligned (same assumption as the vectorized kernel variant).
    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
    __align__(16) int4* slm_v4i =
        reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
         j += blockDim.x * blockDim.y) {
      output_v4i[j] = slm_v4i[j];
    }
  }
}
#endif
template <typename LType> template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) { __device__ inline void regs_shuffle(LType* regs_vec) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
...@@ -196,6 +274,61 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons ...@@ -196,6 +274,61 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
} }
} }
#ifdef __HIP_PLATFORM_AMD__
// DCU/HIP specialization of the row-wise scaling-factor swizzle for plain
// 32-bit (non-vectorized) loads; mirrors the templated LType kernel above.
// Reads K-major scale-inverse data from `input` and writes swizzled tiles to
// `output`.  blockIdx.x walks tile groups along K, blockIdx.y along M.
// Requires dynamic shared memory for one block's worth of tiles at launch.
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel_int(const void* input, void* output, const int M,
                                               const int K) {
  // Always 1 for int/int; retained for symmetry with the vectorized variant
  // where this is sizeof(LType) / sizeof(int).
  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
  // input is in K-major
  // Tile footprint in 32-bit words (4 packed scale factors per int32).
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;
  // The last block along K gets the remainder of the tiles.
  int n_tiles_in_tb = N_TILES_IN_TB;
  const int K_i32 = K / 4;
  if (blockIdx.x == gridDim.x - 1) {
    n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
  }
  // Block-local base pointers (K-major input, tile-contiguous output).
  const int* input_i32 = reinterpret_cast<const int*>(input) +
                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
  extern __shared__ int4 slm_v4i[];
  // load, global -> regs
  int regs_vec[N_SF_PER_TD_PER_TILE];
  // Guard: only threads mapped onto live tiles in this (possibly trimmed)
  // block participate in the load/shuffle/shared-store phase.
  if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
      regs_vec[i] = *reinterpret_cast<const int*>(
          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD);
    }
    // shuffle regs (in-register re-arrangement; defined elsewhere in file)
    regs_shuffle<int>(regs_vec);
    // store, regs -> shared
#pragma unroll
    for (int i = 0; i < N_TILE_PER_TD; i++) {
      /* TODO rotate i */
      // NOTE(review): the int4 reinterpret of regs_vec assumes
      // N_SF_PER_TD_PER_TILE ints form whole int4 groups — confirm against
      // the constant's definition (outside this view).
      slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
          reinterpret_cast<int4*>(regs_vec)[i];
    }
  }
  // Barrier reached by every thread (no barrier inside the guarded branch),
  // so shared memory is complete before the copy-out below.
  __syncthreads();
  // store, shared -> global
  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
  // 16-byte vectorized copy-out; assumes the output base is int4-aligned.
  __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
  for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
    output_v4i[i] = slm_v4i[i];
  }
}
#endif
} // namespace } // namespace
namespace transformer_engine { namespace transformer_engine {
...@@ -253,6 +386,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -253,6 +386,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
switch (vec_load_size) { switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
#else
case 4: case 4:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
...@@ -274,6 +430,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -274,6 +430,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr, <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k); output->scale_inv.dptr, m, k);
break; break;
#endif
default: default:
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
...@@ -286,6 +443,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -286,6 +443,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
switch (vec_load_size) { switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
#else
case 4: case 4:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>, cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
...@@ -307,6 +487,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -307,6 +487,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
<<<num_blocks, block_size, slm_size, stream>>>( <<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break; break;
#endif
default: default:
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
......
...@@ -170,7 +170,11 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -170,7 +170,11 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp
#else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif
out_dbias.data.elt[j] += elt; out_dbias.data.elt[j] += elt;
} }
} }
...@@ -484,6 +488,50 @@ static const char *ActTypeToString[] = { ...@@ -484,6 +488,50 @@ static const char *ActTypeToString[] = {
"dsrelu" // 12 "dsrelu" // 12
}; };
#ifdef __HIP_PLATFORM_AMD__
/* HIPCC has strict rules for __device__ functions usage on host.
It forbids not only calling but also other ODR-use assigning to variables
https://github.com/llvm/llvm-project/issues/105825
Use templated struct wrapper to work around
*/
// Holds an activation-function pointer as a static constexpr member so host
// code can compare operators without directly ODR-using the __device__
// functions themselves (HIPCC restriction — see the comment above).
template <typename ComputeType, typename ParamOP,
          ComputeType (*OP)(ComputeType, const ParamOP &)>
struct ActivationType {
  static constexpr ComputeType (*op)(ComputeType, const ParamOP &) = OP;
};
// Maps an activation operator (given as a template argument) to a stable
// integer id; the ids line up with the ActTypeToString table above
// (1 = sigmoid ... 12 = dsrelu, 0 = unrecognized).
// All comparisons go through ActivationType<...>::op instead of touching OP
// or the activation functions directly, so host code never ODR-uses the
// __device__ functions (HIPCC workaround, see comment above this block).
// NOTE(review): unlike the CUDA path this overload is not constexpr —
// confirm no caller needs the result in a constant expression.
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int get_activation_type() {
  using act = ActivationType<ComputeType, ParamOP, OP>;
  if (act::op == ActivationType<ComputeType, ParamOP, &sigmoid<ComputeType, ComputeType>>::op) {
    return 1;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsigmoid<ComputeType, ComputeType>>::op) {
    return 2;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &gelu<ComputeType, ComputeType>>::op) {
    return 3;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &dgelu<ComputeType, ComputeType>>::op) {
    return 4;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &qgelu<ComputeType, ComputeType>>::op) {
    return 5;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &dqgelu<ComputeType, ComputeType>>::op) {
    return 6;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &silu<ComputeType, ComputeType>>::op) {
    return 7;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsilu<ComputeType, ComputeType>>::op) {
    return 8;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &relu<ComputeType, ComputeType>>::op) {
    return 9;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &drelu<ComputeType, ComputeType>>::op) {
    return 10;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &srelu<ComputeType, ComputeType>>::op) {
    return 11;
  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsrelu<ComputeType, ComputeType>>::op) {
    return 12;
  } else {
    // Unknown activation: callers treat 0 as "no/unsupported activation".
    return 0;
  }
}
#else
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)> template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
constexpr int get_activation_type() { constexpr int get_activation_type() {
constexpr decltype(OP) ActivationList[] = { constexpr decltype(OP) ActivationList[] = {
...@@ -509,6 +557,7 @@ constexpr int get_activation_type() { ...@@ -509,6 +557,7 @@ constexpr int get_activation_type() {
} }
return 0; return 0;
} }
#endif
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP &)> ComputeType (*OP)(ComputeType, const ParamOP &)>
...@@ -734,11 +783,17 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor * ...@@ -734,11 +783,17 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
#ifdef __HIP_PLATFORM_AMD__
cudaFuncSetAttribute((const void *)
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
Param, nvec_in, nvec_out, Empty, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#else
cudaFuncSetAttribute( cudaFuncSetAttribute(
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
Param, nvec_in, nvec_out, Empty, OP>, Param, nvec_in, nvec_out, Empty, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#endif
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param, cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param,
nvec_in, nvec_out, Empty, OP> nvec_in, nvec_out, Empty, OP>
<<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>( <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
...@@ -1195,10 +1250,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu ...@@ -1195,10 +1250,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>); (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
if (full_tile) { if (full_tile) {
#ifdef __HIP_PLATFORM_AMD__
cudaFuncSetAttribute((const void *)
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#else
cudaFuncSetAttribute( cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2>, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#endif
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType, dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
Empty, OP1, OP2> Empty, OP1, OP2>
...@@ -1212,10 +1274,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu ...@@ -1212,10 +1274,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows, reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
n_tiles); n_tiles);
} else { } else {
#ifdef __HIP_PLATFORM_AMD__
cudaFuncSetAttribute((const void *)
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
InputType, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#else
cudaFuncSetAttribute( cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
InputType, OutputType, Empty, OP1, OP2>, InputType, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100);
#endif
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType, dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2> OutputType, Empty, OP1, OP2>
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>( <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
......
...@@ -90,7 +90,11 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O ...@@ -90,7 +90,11 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) { for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp
#else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif
out_dbias.data.elt[j] += elt; out_dbias.data.elt[j] += elt;
} }
} }
......
...@@ -45,7 +45,11 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], ...@@ -45,7 +45,11 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp
#else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in warp
#endif
out_dbias.data.elt[j] += elt; out_dbias.data.elt[j] += elt;
} }
} }
...@@ -469,7 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor ...@@ -469,7 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
param.scale_inv = param.scale_inv =
reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr); reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr); param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
#ifdef __HIP_PLATFORM_AMD__
if (full_tile) {
cudaFuncSetAttribute((const void *)transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute((const void *)transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles);
}
#else
if (full_tile) { if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>, cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100); cudaFuncAttributePreferredSharedMemoryCarveout, 100);
...@@ -483,7 +501,7 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor ...@@ -483,7 +501,7 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>( <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles); param, row_length, num_rows, n_tiles);
} }
#endif
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out, reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
stream);); // NOLINT(*) stream);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
......
...@@ -54,6 +54,7 @@ static_assert(ITERATIONS >= 1); ...@@ -54,6 +54,7 @@ static_assert(ITERATIONS >= 1);
__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType> float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...@@ -273,7 +274,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -273,7 +274,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
#endif // __HIP_PLATFORM_AMD__
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType, float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
size_t SCALE_DIM_Y, size_t SCALE_DIM_X> size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
...@@ -720,14 +723,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -720,14 +723,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
#endif // __HIP_PLATFORM_AMD__
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Cast_fp8_gated is not surpported in rocm yet.");
#else #else
if (output->has_data()) { if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
...@@ -810,8 +813,7 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP ...@@ -810,8 +813,7 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Cast_mxfp8_gated is not surpported in rocm yet.");
#else #else
const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data();
......
...@@ -56,6 +56,7 @@ constexpr size_t MXFP8_BUFF_STAGES_NUM = ...@@ -56,6 +56,7 @@ constexpr size_t MXFP8_BUFF_STAGES_NUM =
constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32
static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y, float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y,
size_t SCALE_DIM_X> size_t SCALE_DIM_X>
...@@ -462,6 +463,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) ...@@ -462,6 +463,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread); destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
#endif // __HIP_PLATFORM_AMD__
constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_CHUNK_DIM_X = 128;
...@@ -479,6 +481,7 @@ constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 ...@@ -479,6 +481,7 @@ constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM);
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &), template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType> typename IType, typename OType>
__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
...@@ -656,6 +659,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -656,6 +659,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread); destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
#endif // __HIP_PLATFORM_AMD__
constexpr size_t CHUNKS_PER_BLOCK = 128; constexpr size_t CHUNKS_PER_BLOCK = 128;
constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;
...@@ -856,8 +860,7 @@ template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, cons ...@@ -856,8 +860,7 @@ template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, cons
void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) { Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Cast_fp8_2D is not surpported in rocm yet.");
#else #else
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -931,8 +934,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -931,8 +934,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani) const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Mxfp8_quantize is not surpported in rocm yet.");
#else #else
bool use_rowwise_scaling = output->has_data(); bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data(); bool use_colwise_scaling = output->has_columnwise_data();
...@@ -1057,10 +1059,23 @@ __device__ inline float dequantize_func(float value, const DequantizeParam &para ...@@ -1057,10 +1059,23 @@ __device__ inline float dequantize_func(float value, const DequantizeParam &para
} // namespace detail } // namespace detail
#ifdef __HIP_PLATFORM_AMD__
// Host-safe holder for a device function pointer: HIPCC rejects direct
// ODR-use of __device__ functions in host code, but reading the pointer
// through a constexpr static member of a template specialization is allowed.
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
struct KernelType {
  static constexpr float (*op)(float, const ParamOP &) = OP;
};
#endif
template <typename ParamOP, float (*OP)(float, const ParamOP &)> template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
using kernel = KernelType<ParamOP, OP>;
constexpr float (*UnaryOP)(float, const ParamOP &) = (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, input.data.dtype, IType,
...@@ -1084,7 +1099,12 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, ...@@ -1084,7 +1099,12 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
template <typename ParamOP, float (*OP)(float, const ParamOP &)> template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
cudaStream_t stream) { cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
using kernel = KernelType<ParamOP, OP>;
constexpr float (*UnaryOP)(float, const ParamOP &) = (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
#else
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
#endif
const size_t N = product(input->data.shape); const size_t N = product(input->data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input->data.dtype, IType, input->data.dtype, IType,
......
...@@ -7,7 +7,11 @@ ...@@ -7,7 +7,11 @@
#include <filesystem> #include <filesystem>
#include "../common.h" #include "../common.h"
#ifdef USE_ROCM
#include "../util/hip_runtime.h"
#else
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#endif
namespace transformer_engine { namespace transformer_engine {
......
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifdef USE_ROCM
#include "hip_nvml.h"
#else
#include "cuda_nvml.h" #include "cuda_nvml.h"
#endif
#include "shared_lib_wrapper.h" #include "shared_lib_wrapper.h"
......
...@@ -10,9 +10,15 @@ ...@@ -10,9 +10,15 @@
#include <mutex> #include <mutex>
#include "../common.h" #include "../common.h"
#ifdef USE_ROCM
#include "../util/hip_driver.h"
#include "../util/system.h"
#include "common/util/hip_runtime.h"
#else
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#include "../util/system.h" #include "../util/system.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
namespace transformer_engine { namespace transformer_engine {
......
...@@ -50,6 +50,7 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; ...@@ -50,6 +50,7 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X;
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(ITERATIONS >= 1); static_assert(ITERATIONS >= 1);
#ifndef __HIP_PLATFORM_AMD__
template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X> template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) __global__ void __launch_bounds__(THREADS_PER_CHUNK)
dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
...@@ -229,6 +230,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -229,6 +230,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
#endif // __HIP_PLATFORM_AMD__
static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
...@@ -253,8 +255,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ...@@ -253,8 +255,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Mxfp8_dequantize is not surpported in rocm yet.");
#else #else
bool use_rowwise_scaling = input.has_data(); bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data(); bool use_colwise_scaling = input.has_columnwise_data();
...@@ -337,8 +338,8 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -337,8 +338,8 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
}
#endif #endif
}
} // namespace dequantization } // namespace dequantization
namespace detail { namespace detail {
......
...@@ -9,7 +9,11 @@ ...@@ -9,7 +9,11 @@
#include <vector> #include <vector>
#ifdef __HIP_PLATFORM_AMD__
#include "hip_runtime.h"
#else
#include "cuda_runtime.h" #include "cuda_runtime.h"
#endif
#include "logging.h" #include "logging.h"
namespace transformer_engine::detail { namespace transformer_engine::detail {
......
...@@ -12,7 +12,11 @@ ...@@ -12,7 +12,11 @@
#include <transformer_engine/fused_attn.h> #include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#ifdef __HIP_PLATFORM_AMD__
#include "hip_runtime.h"
#else
#include "cuda_runtime.h" #include "cuda_runtime.h"
#endif
#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \ pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
......
...@@ -11,7 +11,11 @@ ...@@ -11,7 +11,11 @@
#include <utility> #include <utility>
#include "../common.h" #include "../common.h"
#ifdef USE_ROCM
#include "../util/hip_driver.h"
#else
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#endif
#include "../util/string.h" #include "../util/string.h"
#include "../util/system.h" #include "../util/system.h"
......
...@@ -19,8 +19,13 @@ ...@@ -19,8 +19,13 @@
#include <vector> #include <vector>
#include "../common.h" #include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
#include "../util/hip_driver.h"
#include "../util/hip_runtime.h"
#else
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#endif
namespace transformer_engine { namespace transformer_engine {
......
...@@ -258,7 +258,7 @@ struct Converter<float2, hip_bfloat16x2> { ...@@ -258,7 +258,7 @@ struct Converter<float2, hip_bfloat16x2> {
static inline __device__ hip_bfloat16x2 convert(const float2 &x) { static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
union { union {
hip_bfloat16x2 raw; hip_bfloat16x2 raw;
hip_bfloat16 elt[2]; __hip_bfloat16 elt[2];
} tmp; } tmp;
tmp.elt[0] = __hip_bfloat16(x.x); tmp.elt[0] = __hip_bfloat16(x.x);
tmp.elt[1] = __hip_bfloat16(x.y); tmp.elt[1] = __hip_bfloat16(x.y);
...@@ -1020,6 +1020,13 @@ struct Quantized_Limits { ...@@ -1020,6 +1020,13 @@ struct Quantized_Limits {
static constexpr float emax_rcp = 1.0 / emax; static constexpr float emax_rcp = 1.0 / emax;
}; };
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
// TODO: nan/inf needs to be set for any value // TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax. // of nan/inf in input not just amax.
......
...@@ -108,7 +108,7 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm( ...@@ -108,7 +108,7 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias, std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out, transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate, bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) bool use_split_accumulator, int math_sm_count);
#endif #endif
/*************************************************************************************************** /***************************************************************************************************
......
...@@ -18,8 +18,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( ...@@ -18,8 +18,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Get_fused_attn_backend is not surpported in rocm for normalization yet.");
#else #else
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
...@@ -101,8 +100,7 @@ std::vector<py::object> fused_attn_fwd( ...@@ -101,8 +100,7 @@ std::vector<py::object> fused_attn_fwd(
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) { const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Fused_attn_fwd is not surpported in rocm for normalization yet.");
#else #else
using namespace transformer_engine; using namespace transformer_engine;
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
...@@ -294,8 +292,7 @@ std::vector<py::object> fused_attn_bwd( ...@@ -294,8 +292,7 @@ std::vector<py::object> fused_attn_bwd(
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) { py::handle dp_quantizer, py::handle dqkv_quantizer) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
static_assert(false, assert(false);
"Fused_attn_bwd is not surpported in rocm for normalization yet.");
#else #else
using namespace transformer_engine; using namespace transformer_engine;
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
...@@ -1051,8 +1048,7 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t ...@@ -1051,8 +1048,7 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
template <typename scalar_t> template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) { int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn:: transformer_engine::fused_attn::convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(), reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d); b, max_seq_len, h, d);
...@@ -1091,8 +1087,7 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, ...@@ -1091,8 +1087,7 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
template <typename scalar_t> template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) { int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn:: transformer_engine::fused_attn::convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(), reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d); b, max_seq_len, h, d);
...@@ -1152,15 +1147,13 @@ void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_ ...@@ -1152,15 +1147,13 @@ void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) { v_cache.data_ptr() != nullptr) {
if (is_non_paged) { if (is_non_paged) {
transformer_engine::fused_attn:: transformer_engine::fused_attn::reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(), page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len); cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
} }
transformer_engine::fused_attn:: transformer_engine::fused_attn::copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()), reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
......
...@@ -12,7 +12,11 @@ ...@@ -12,7 +12,11 @@
#include "../common.h" #include "../common.h"
#include "common.h" #include "common.h"
#ifdef USE_ROCM
#include "common/util/hip_runtime.h"
#else
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#endif
#include "common/util/system.h" #include "common/util/system.h"
#include "extensions.h" #include "extensions.h"
#include "pybind.h" #include "pybind.h"
...@@ -531,9 +535,10 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm( ...@@ -531,9 +535,10 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
wrappers.emplace_back(std::move(wsp)); wrappers.emplace_back(std::move(wsp));
} }
// For now, we only have multi-stream cublas backend. // For now, we only have multi-stream cublas backend.
nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_workspace.data(), accumulate, use_split_accumulator, te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream()); math_sm_count, at::cuda::getCurrentCUDAStream());
return bias; return bias;
} }
......
...@@ -306,7 +306,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -306,7 +306,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1}); std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
} }
if (columnwise_usage) { if (columnwise_usage) {
...@@ -317,7 +317,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor( ...@@ -317,7 +317,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1}); std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
} }
this->set_quantization_params(&tensor); this->set_quantization_params(&tensor);
......
...@@ -92,7 +92,7 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]: ...@@ -92,7 +92,7 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
if not _multi_stream_cublas_batchgemm_workspace: if not _multi_stream_cublas_batchgemm_workspace:
for _ in range(tex._num_cublas_batchgemm_streams): for _ in range(tex._num_cublas_batchgemm_streams):
_multi_stream_cublas_batchgemm_workspace.append( _multi_stream_cublas_batchgemm_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") torch.empty(128, dtype=torch.uint8, device="cuda")
) )
return _multi_stream_cublas_batchgemm_workspace return _multi_stream_cublas_batchgemm_workspace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment