Commit fca88163 authored by wenjh

[Perf] blockwise 1d better perf

Split the HIP fp32 blockwise-1D cast-transpose kernel into dedicated rowwise and colwise kernels, switch the amax reductions to __shfl_xor butterflies over 64-lane wavefronts, and trigger programmatic launch completion early on SM90+ so the next kernel's global loads can overlap.

Signed-off-by: wenjh <wenjh@sugon.com>
parent ca1e98b6
@@ -561,8 +561,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       for (int i = 0; i < kThreadTileCol; ++i) {
         if constexpr (std::is_same_v<OType, int8_t>) {
           output_vec.data.elt[i] = static_cast<OType>(lroundf(
-              fmaxf(-127.0f,
-                    fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]))));
+              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
+                                               thr_scale.data.elt[i]))));
         } else {
           output_vec.data.elt[i] = static_cast<OType>(
               static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
@@ -654,6 +654,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   __syncthreads();
 
+  // If not returning columnwise, we trigger the next kernel here so that its load from global
+  // memory can overlap with this kernel's rowwise return.
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+  if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
+    cudaTriggerProgrammaticLaunchCompletion();
+  }
+#endif
+
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
     constexpr int r_stride =
@@ -760,6 +768,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     }
   }
 
+  // If returning columnwise, we trigger the next kernel here so that its load from global
+  // memory can overlap with this kernel's columnwise return.
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+  if (return_columnwise_gemm_ready || return_columnwise_compact) {
+    cudaTriggerProgrammaticLaunchCompletion();
+  }
+#endif
+
   // Step 3: Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
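Note on the two triggers added above: cudaTriggerProgrammaticLaunchCompletion() only pays off when the following kernel is launched with programmatic dependent launch enabled; this patch adds nothing on the host side for that. A minimal host-side sketch of the pairing (CUDA 11.8+, sm_90+; NextKernel and LaunchOverlapped are illustrative names, not part of this patch):

    #include <cuda_runtime.h>

    __global__ void NextKernel() {
      // Loads of data independent of the upstream kernel may start early;
      // synchronize before reading anything the upstream kernel writes.
      cudaGridDependencySynchronize();
    }

    void LaunchOverlapped(dim3 grid, dim3 block, cudaStream_t stream) {
      cudaLaunchConfig_t config = {};
      config.gridDim = grid;
      config.blockDim = block;
      config.stream = stream;
      cudaLaunchAttribute attr = {};
      attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
      attr.val.programmaticStreamSerializationAllowed = 1;
      config.attrs = &attr;
      config.numAttrs = 1;
      // NextKernel may begin as soon as the preceding kernel calls
      // cudaTriggerProgrammaticLaunchCompletion(), instead of waiting
      // for it to fully retire.
      cudaLaunchKernelEx(&config, NextKernel);
    }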
@@ -883,11 +899,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
                   "num_iterations should be 1 for columnwise non-transpose case");
     const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
     const int warp_idx = threadIdx.x / kThreadsPerWarp;
-    if(warp_idx >= kNumColBlocks) {
+    if (warp_idx >= kNumColBlocks) {
       return;  // No work to do
     }
     const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
     int c_s = warp_idx * num_smem_reads;               // Column in shared memory
     size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s;  // Row in global memory
     const size_t c_g =
         static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem;  // Column in global memory
@@ -956,8 +972,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
       for (int i = 0; i < kThreadTileCol; ++i) {
         if constexpr (std::is_same_v<OType, int8_t>) {
           output_vec.data.elt[i] = static_cast<OType>(lroundf(
-              fmaxf(-127.0f,
-                    fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]))));
+              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
+                                               thr_scale.data.elt[i]))));
         } else {
           output_vec.data.elt[i] = static_cast<OType>(
               static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
@@ -979,45 +995,51 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 }
 
 #ifdef __HIP_PLATFORM_AMD__
-constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
-constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;
+constexpr size_t kThreadsPerWarp_blocklen_128 = 64;
+
+constexpr int kTileDim64_Rowwise = 64;
+constexpr int kNVecSMem_Rowwise = 4;  // The number of elements each LDS/STS touches
+constexpr int kThreadsPerBlock_Rowwise = 512;  // Thread block size, 8 warps in total
+constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise;
+constexpr int kSMemCol_Rowwise = (kTileDim / kNVecSMem_Rowwise);
+constexpr int kSMemSize_Rowwise = kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;
+
 template <bool kAligned, typename CType, typename IType, typename OType>
-__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
-    const IType* const input, OType* const output_c, OType* const output_t,
-    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
-    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
-    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
-    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
-    const bool pow_2_scaling) {
+__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
+    block_scaled_1d_cast_transpose_kernel_rowwise(const IType* const input, OType* const output_c,
+                                                  CType* const tile_scales_inv_c,
+                                                  const size_t row_length, const size_t num_rows,
+                                                  const size_t scale_stride_x,
+                                                  const size_t scale_stride_y, const float epsilon,
+                                                  FP8BlockwiseRowwiseOption rowwise_option,
+                                                  const bool pow_2_scaling) {
   bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
-  bool return_columnwise_gemm_ready =
-      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
-  bool return_columnwise_compact =
-      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
-  using SMemVec = Vec<IType, kNVecSMem>;
+  using SMemVec = Vec<IType, kNVecSMem_Rowwise>;
   using OVec = Vec<OType, kNVecOut>;
   union IVec {
     Vec<IType, kNVecIn> input_type;
-    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
+    Vec<SMemVec, kNVecIn / kNVecSMem_Rowwise> smem_type;
   };
   extern __shared__ char smem_base[];
   SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
   // Step 1: Load input to shared memory
   {
-    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim / r_stride;
+    constexpr int r_stride =
+        kThreadsPerBlock_Rowwise / kNumThreadsLoad;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;  // 64/16 = 4
     const int c_s =
-        (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
+        (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem_Rowwise);  // Column in shared memory
     int r_s = threadIdx.x / kNumThreadsLoad;  // Row in shared memory
-    const size_t c_g =
-        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
+                       c_s * kNVecSMem_Rowwise;  // Column in global memory
+    size_t r_g =
+        static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
     const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
-                                            : 0;  // For not aligned case
+    const size_t num_ele = c_g < row_length
+                               ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
+                               : 0;  // For not aligned case
     const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
 #pragma unroll
     for (int iter = 0; iter < num_iterations; ++iter) {
@@ -1032,13 +1054,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
         input_vec.input_type.clear();
       }
     }
-    // Step 1.2: Write to shared memory - Column Major
+    // Step 1.2: Write to shared memory - Row Major
 #pragma unroll
-    for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
+    for (int i = 0; i < kNVecIn / kNVecSMem_Rowwise; ++i) {
       int c = c_s + i;
       int r = r_s;
-      // Column Major Store
-      smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
+      // Row Major Store
+      smem[r * kSMemCol_Rowwise + c] = input_vec.smem_type.data.elt[i];
     }
     // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
     input_g += stride_g;
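The layout switch above is the core of the rowwise specialization: the old fused kernel staged the tile column-major in shared memory because the transpose path read it by columns, but a rowwise-only kernel reads shared memory exactly as it wrote it, so a row-major layout keeps both the global load and the rowwise store walking consecutive addresses. A hedged sketch of the two addressing schemes, in units of SMemVec slots (the helper names are illustrative; kTileDim and kSMemCol_Rowwise are the constants defined in this file):

    // Element (r, c) of the staged tile:
    constexpr int col_major_idx(int r, int c, int tile_dim)  { return c * tile_dim + r; }   // old fused kernel
    constexpr int row_major_idx(int r, int c, int smem_cols) { return r * smem_cols + c; }  // new rowwise kernel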
@@ -1054,63 +1076,62 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
     constexpr int r_stride =
-        kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim / r_stride;
-    const int c_s =
-        (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
+        kThreadsPerBlock_Rowwise / kNumThreadsStore;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsStore) *
+                    (kNVecOut / kNVecSMem_Rowwise);  // Column in shared memory
     int r_s = threadIdx.x / kNumThreadsStore;  // Row in shared memory
-    const size_t c_g =
-        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
+                       c_s * kNVecSMem_Rowwise;  // Column in global memory
+    size_t r_g =
+        static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
     const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
-                                            : 0;  // For not aligned case
+    const size_t num_ele = c_g < row_length
+                               ? std::min(static_cast<size_t>(kNVecOut), row_length - c_g)
+                               : 0;  // For not aligned case
     OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
     // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
     // the first thread to do the reduction.
-    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    const unsigned src_lane =
+        (threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore * kNumThreadsStore;
     // This mask represents which threads should do the reduction together.
     const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
     const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
 #pragma unroll
     for (int iter = 0; iter < num_iterations; ++iter) {
-      SMemVec smem_vec[kNVecOut / kNVecSMem];
-      // Step 2.1: Load from shared memory to registers - Column Major
+      SMemVec smem_vec[kNVecOut / kNVecSMem_Rowwise];
+      // Step 2.1: Load from shared memory to registers - Row Major
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
         int c = c_s + i;
         int r = r_s;
-        // Column Major Read
-        smem_vec[i] = smem[c * kTileDim + r];
+        // Row Major Read
+        smem_vec[i] = smem[r * kSMemCol_Rowwise + c];
       }
       // Step 2.2: Compute local amax
       CType amax = 0;
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
 #pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
+        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
         }
       }
       // Step 2.3: Reduce amax
 #pragma unroll
       for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-#ifdef __HIP_PLATFORM_AMD__
-        const float other_amax =
-            __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-#else
-        const float other_amax = __shfl_down_sync(mask, amax, delta);
-#endif
+        // const float other_amax =
+        //     __shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+        const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
         __builtin_assume(amax >= 0);
         __builtin_assume(other_amax >= 0);
         amax = fmaxf(amax, other_amax);
       }
-#ifdef __HIP_PLATFORM_AMD__
-      amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
-#else
-      amax = __shfl_sync(mask, amax, src_lane);
-#endif
       CType scale;
       // Step 2.4: Compute scale
       scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
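The reduction rewrite in this hunk is the wavefront-level change: a `__shfl_down_sync` tree followed by a `__shfl_sync` broadcast is replaced by a `__shfl_xor` butterfly, which leaves every lane in the group holding the maximum (so the trailing broadcast disappears) and addresses the full 64-lane AMD wavefront explicitly. A minimal sketch of the pattern, assuming an aligned power-of-two group size inside the wavefront (the helper name is illustrative):

    // Butterfly max-reduction across an aligned group of lanes in a 64-lane
    // wavefront; after the loop every lane in the group holds the group max.
    __device__ float group_max(float v, int group_size) {
      for (int delta = group_size / 2; delta > 0; delta /= 2) {
        v = fmaxf(v, __shfl_xor(v, delta, 64));  // exchange with lane ^ delta
      }
      return v;
    }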
@@ -1121,21 +1142,21 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       }
       if (write_scale_inv) {
         CType scale_inv = 1.0 / scale;
-        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
+        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;
         size_t col_idx = static_cast<size_t>(blockIdx.x);
         tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
       }
       // Step 2.6: Quantize
       OVec output_vec;
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
 #pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
+        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
           if constexpr (std::is_same_v<OType, int8_t>) {
-            output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
+            output_vec.data.elt[i * kNVecSMem_Rowwise + j] = static_cast<OType>(lroundf(fmaxf(
                 -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
           } else {
-            output_vec.data.elt[i * kNVecSMem + j] =
+            output_vec.data.elt[i * kNVecSMem_Rowwise + j] =
                 static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
           }
         }
@@ -1156,28 +1177,109 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       }
     }
   }
+}
+
+constexpr int kTileDim64_Colwise = 64;
+constexpr int kNVecSMem_Colwise = 2;
+constexpr int kSMemRow_Colwise = kTileDim;
+constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise);
+constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise;
+constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn;
+constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut;
+constexpr int kThreadsPerBlock_Colwise = 256;
+constexpr int kNumWarps_Colwise = kThreadsPerBlock_Colwise / kThreadsPerWarp_blocklen_128;
+
+template <bool kAligned, typename CType, typename IType, typename OType>
+__global__ void __launch_bounds__(kThreadsPerBlock_Colwise)
+    block_scaled_1d_cast_transpose_kernel_colwise(
+        const IType* const input, OType* const output_t, CType* const tile_scales_inv_t,
+        const size_t row_length, const size_t num_rows, const size_t scale_t_stride_x,
+        const size_t scale_t_stride_y, const float epsilon,
+        FP8BlockwiseColumnwiseOption columnwise_option, const bool pow_2_scaling) {
+  bool return_columnwise_gemm_ready =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+  bool return_columnwise_compact =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
+  using SMemVec = Vec<IType, kNVecSMem_Colwise>;
+  using OVec = Vec<OType, kNVecOut>;
+  union IVec {
+    Vec<IType, kNVecIn> input_type;
+    Vec<SMemVec, kNVecIn / kNVecSMem_Colwise> smem_type;
+  };
+  extern __shared__ char smem_base[];
+  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
+  // Step 1: Load input to shared memory
+  {
+    constexpr int r_stride =
+        kThreadsPerBlock_Colwise / kNumThreadsLoad_Colwise;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsLoad_Colwise) *
+                    (kNVecIn / kNVecSMem_Colwise);  // Column in shared memory
+    int r_s = threadIdx.x / kNumThreadsLoad_Colwise;  // Row in shared memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise +
+                       c_s * kNVecSMem_Colwise;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
+    const size_t num_ele = c_g < row_length
+                               ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
+                               : 0;  // For not aligned case
+    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      IVec input_vec;
+      // Step 1.1: Load from global memory (input) to registers
+      if constexpr (kAligned) {
+        input_vec.input_type.load_from(input_g);
+      } else {
+        if (r_g < num_rows) {
+          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
+        } else {
+          input_vec.input_type.clear();
+        }
+      }
+      // Step 1.2: Write to shared memory - Row Major
+#pragma unroll
+      for (int i = 0; i < kNVecIn / kNVecSMem_Colwise; ++i) {
+        int c = c_s + i;
+        int r = r_s;
+        // Row Major Store
+        smem[r * kSMemCol_Colwise + c] = input_vec.smem_type.data.elt[i];
+      }
+      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
+      input_g += stride_g;
+      r_s += r_stride;
+      if constexpr (!kAligned) {
+        r_g += r_stride;
+      }
+    }
+  }
+  __syncthreads();
+
   // Step 3: Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
-        kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
-    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
-    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
-    int c_s = threadIdx.x / kNumThreadsStore;  // Column in shared memory
-    size_t r_g =
-        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Row in global memory
+        kThreadsPerBlock_Colwise / kNumThreadsStore_Colwise;  // Stride in columns of shared memory
+    constexpr int num_iterations = kTileDim64_Colwise / (c_stride * kNVecSMem_Colwise);
+    const int r_s = (threadIdx.x % kNumThreadsStore_Colwise) * kNVecOut;  // Row in shared memory
+    int c_s = threadIdx.x / kNumThreadsStore_Colwise;  // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise +
+                 c_s * kNVecSMem_Colwise;  // Row in global memory
     const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Column in global memory
     const size_t stride_g =
-        static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
-    const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
-                                          : 0;  // For not aligned case
+        static_cast<size_t>(c_stride) * kNVecSMem_Colwise * num_rows;  // Stride in global memory
+    const size_t num_ele = c_g < num_rows ? std::min(static_cast<size_t>(kNVecOut), num_rows - c_g)
+                                          : 0;  // For not aligned case
     OType* output_g = &output_t[r_g * num_rows + c_g];  // Output address in global memory
     // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
     // the first thread to do the reduction.
-    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp_blocklen_128) /
+                              kNumThreadsStore_Colwise * kNumThreadsStore_Colwise;
     // This mask represents which threads should do the reduction together.
-    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
-    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
+    const unsigned mask = ((1 << kNumThreadsStore_Colwise) - 1) << src_lane;
+    const bool is_src_lane = (threadIdx.x % kNumThreadsStore_Colwise) == 0;
 #pragma unroll
     for (int iter = 0; iter < num_iterations; ++iter) {
       SMemVec smem_vec[kNVecOut];
@@ -1186,11 +1288,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       for (int i = 0; i < kNVecOut; ++i) {
         int r = r_s + i;
         int c = c_s;
-        // Column Major Read
-        smem_vec[i] = smem[c * kTileDim + r];
+        // Row Major Read
+        smem_vec[i] = smem[r * kSMemCol_Colwise + c];
       }
 #pragma unroll
-      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
+      for (int smem_idx = 0; smem_idx < kNVecSMem_Colwise; ++smem_idx) {
         // Step 3.2: Compute local amax
         CType amax = 0;
 #pragma unroll
@@ -1199,22 +1301,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
         }
         // Step 3.3: Reduce amax
 #pragma unroll
-        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-#ifdef __HIP_PLATFORM_AMD__
-          const float other_amax =
-              __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-#else
-          const float other_amax = __shfl_down_sync(mask, amax, delta);
-#endif
+        for (int delta = kNumThreadsStore_Colwise / 2; delta > 0; delta /= 2) {
+          // const float other_amax =
+          //     __shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+          const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
           __builtin_assume(amax >= 0);
           __builtin_assume(other_amax >= 0);
           amax = fmaxf(amax, other_amax);
         }
-#ifdef __HIP_PLATFORM_AMD__
-        amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
-#else
-        amax = __shfl_sync(mask, amax, src_lane);
-#endif
         // Step 3.4: Compute scale
         CType scale;
         scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
@@ -1225,7 +1321,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
         }
         if (write_scale_inv) {
           CType scale_inv = 1.0 / scale;
-          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
+          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise +
+                           c_s * kNVecSMem_Colwise + smem_idx;
           size_t col_idx = static_cast<size_t>(blockIdx.y);
           tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
         }
@@ -1255,34 +1352,33 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       output_g += stride_g;
       c_s += c_stride;
       if constexpr (!kAligned) {
-        r_g += c_stride * kNVecSMem;
+        r_g += c_stride * kNVecSMem_Colwise;
       }
     }
   }
-  // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
   if (return_columnwise_compact) {
     // thread tile should be 4x16, 16 means 8 smem reads
-    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
+    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp_blocklen_128;
     constexpr int kThreadTileCol = kNVecOut;
     using RegVec = Vec<IType, kThreadTileCol>;
-    using RegScaleVec = Vec<CType, kThreadTileCol>;
-    constexpr int num_smem_reads = kNVecOut / kNVecSMem;
+    using RegScaleVec = Vec<CType, kThreadTileCol>;  // float, 16
+    constexpr int num_smem_reads = kNVecOut / kNVecSMem_Colwise;
     // c_stride will not be used here because we only have one iteration
     // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
     constexpr int num_iterations =
-        kTileDim / (kNumWarps * kThreadTileCol);  // should be only one iteration
+        kTileDim64_Colwise / (kNumWarps_Colwise * kThreadTileCol);  // should be only one iteration
     static_assert(num_iterations == 1,
                   "num_iterations should be 1 for columnwise non-transpose case");
-    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
-    const int warp_idx = threadIdx.x / kThreadsPerWarp;
+    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp_blocklen_128;
+    const int warp_idx = threadIdx.x / kThreadsPerWarp_blocklen_128;
     const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
     int c_s = warp_idx * num_smem_reads;               // Column in shared memory
     size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
-    const size_t c_g =
-        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
-    const size_t num_ele = c_g < row_length
-                               ? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
-                               : 0;  // For not aligned case
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise +
+                       c_s * kNVecSMem_Colwise;  // Column in global memory
+    const size_t num_ele = c_g < row_length
+                               ? std::min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
+                               : 0;  // For not aligned case
 #pragma unroll
     for (int iter = 0; iter < num_iterations; ++iter) {
@@ -1296,11 +1392,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
         for (int j = 0; j < num_smem_reads; ++j) {
           int c = c_s + j;
-          SMemVec smem_vec = smem[c * kTileDim + r];
+          SMemVec smem_vec = smem[r * kSMemCol_Colwise + c];
           // copy smem_vec to reg vec with its elements
 #pragma unroll
-          for (int k = 0; k < kNVecSMem; ++k) {
-            reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
+          for (int k = 0; k < kNVecSMem_Colwise; ++k) {
+            reg_vec[i].data.elt[j * kNVecSMem_Colwise + k] = smem_vec.data.elt[k];
           }
         }
       }
@@ -1314,13 +1410,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       }
       // Step 3.3: Reduce amax
       const bool is_src_lane = thr_idx_in_warp == 0;
-      amax = warp_reduce_max<kThreadsPerWarp>(amax);
-      constexpr int lane_zero = 0;
-#ifdef __HIP_PLATFORM_AMD__
-      amax = __shfl_sync((unsigned long long)(0xFFFFFFFF), amax, lane_zero, kThreadsPerWarp);
-#else
-      amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
-#endif
+#pragma unroll
+      for (int delta = kThreadsPerWarp_blocklen_128 / 2; delta > 0; delta /= 2) {
+        // const float other_amax =
+        //     __shfl_xor_sync((unsigned long long)(0xFFFFFFFF), amax, delta, kThreadsPerWarp);
+        const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
+        __builtin_assume(amax >= 0);
+        __builtin_assume(other_amax >= 0);
+        amax = fmaxf(amax, other_amax);
+      }
       // Step 3.4: Compute scale
       CType scale;
       scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
@@ -1333,7 +1432,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
         if (write_scale_inv) {
           CType scale_inv = 1.0 / scale;
           size_t row_idx = static_cast<size_t>(blockIdx.y);
-          size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
+          size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise +
+                           c_s * kNVecSMem_Colwise + reg_idx;
           tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
         }
       }
@@ -1346,8 +1446,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       for (int i = 0; i < kThreadTileCol; ++i) {
         if constexpr (std::is_same_v<OType, int8_t>) {
           output_vec.data.elt[i] = static_cast<OType>(lroundf(
-              fmaxf(-127.0f,
-                    fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]))));
+              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
+                                               thr_scale.data.elt[i]))));
         } else {
           output_vec.data.elt[i] = static_cast<OType>(
               static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
@@ -1471,41 +1571,42 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 #ifdef __HIP_PLATFORM_AMD__
         while (true) {
           if (128 == block_len) {
-            if constexpr (std::is_same_v<InputType, float>) {
-              size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
+            if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+              size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
+              const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
+              const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2));
+              dim3 grid(num_blocks_x, num_blocks_y, 1);
               if (smem_bytes >= 48 * 1024) {
                 cudaError_t err = cudaFuncSetAttribute(
-                    (const void*)&block_scaled_1d_cast_transpose_kernel_fp32<
+                    (const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<
                         kAligned, float, InputType, OutputType>,
                     cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
               }
-              block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType,
-                                                         OutputType>
-                  <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+              block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType,
+                                                            OutputType>
+                  <<<grid, kThreadsPerBlock_Rowwise, smem_bytes, stream>>>(
                       reinterpret_cast<const InputType*>(input.dptr),
                       reinterpret_cast<OutputType*>(output.dptr),
-                      reinterpret_cast<OutputType*>(output_t.dptr),
-                      reinterpret_cast<float*>(scale_inv.dptr),
-                      reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-                      scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-                      epsilon, rowwise_option, columnwise_option, pow2_scale);
-            } else {
-              size_t smem_bytes = kSMemSize * sizeof(InputType);
-              if (smem_bytes >= 48 * 1024) {
-                cudaError_t err = cudaFuncSetAttribute(
-                    (const void*)&block_scaled_1d_cast_transpose_kernel<
-                        kAligned, float, InputType, OutputType>,
-                    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-              }
-              block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
-                  <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
-                      reinterpret_cast<const InputType*>(input.dptr),
-                      reinterpret_cast<OutputType*>(output.dptr),
-                      reinterpret_cast<OutputType*>(output_t.dptr),
-                      reinterpret_cast<float*>(scale_inv.dptr),
-                      reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-                      scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-                      epsilon, rowwise_option, columnwise_option, pow2_scale);
+                      reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows,
+                      scale_stride_x, scale_stride_y, epsilon, rowwise_option, pow2_scale);
+            }
+            if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
+              size_t smem_bytes = kSMemSize_Colwise * sizeof(InputType);
+              const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len / 2));
+              const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len));
+              dim3 grid(num_blocks_x, num_blocks_y, 1);
+              cudaError_t err = cudaFuncSetAttribute(
+                  (const void*)&block_scaled_1d_cast_transpose_kernel_colwise<
+                      kAligned, float, InputType, OutputType>,
+                  cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+              block_scaled_1d_cast_transpose_kernel_colwise<kAligned, float, InputType,
+                                                            OutputType>
+                  <<<grid, kThreadsPerBlock_Colwise, smem_bytes, stream>>>(
+                      reinterpret_cast<const InputType*>(input.dptr),
+                      reinterpret_cast<OutputType*>(output_t.dptr),
+                      reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
+                      scale_t_stride_x, scale_t_stride_y, epsilon, columnwise_option,
+                      pow2_scale);
             }
             break;
           }
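Worth noting about the two launches above: the split swaps the tile aspect ratio. For block_len == 128 the rowwise grid covers 128-column by 64-row tiles, while the colwise grid covers 64-column by 128-row tiles, so each kernel keeps a full 128-element scaling block along its own reduction axis. A hedged sketch of the arithmetic for a hypothetical 4096 x 4096 input:

    // rowwise: 32 x 64 = 2048 blocks, each quantizing 64 rows of 128 columns
    // colwise: 64 x 32 = 2048 blocks, each quantizing 128 rows of 64 columns
    const size_t rowwise_grid_x = DIVUP((size_t)4096, (size_t)128);  // 32
    const size_t rowwise_grid_y = DIVUP((size_t)4096, (size_t)64);   // 64
    const size_t colwise_grid_x = DIVUP((size_t)4096, (size_t)64);   // 64
    const size_t colwise_grid_y = DIVUP((size_t)4096, (size_t)128);  // 32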
@@ -1558,8 +1659,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                 pow2_scale);
 #endif
             )  // kAligned
         )  // OutputType
     )  // InputType
   NVTE_CHECK_CUDA(cudaGetLastError());
 }