Commit 5a876f58 authored by zc20020701, committed by wenjh
Browse files

[DCU] Remove redundant shared memory in rowwise kernel


Signed-off-by: zhaochao <zhaochao1@sugon.com>

See merge request dcutoolkit/deeplearing/TransformerEngine!72
Co-authored-by: zhaochao <zhaochao1@sugon.com>
parent 00f836ef
...@@ -989,15 +989,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -989,15 +989,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp_blocklen_128 = 64; constexpr size_t kThreadsPerWarp_blocklen_128 = 64;
constexpr int kTileDim64_Rowwise = 64; // Optimized Rowwise kernel: Direct register processing without shared memory
constexpr int kNVecSMem_Rowwise = 4; // The number of elements each LDS/STS touches // Each warp (64 threads) processes multiple rows, 8 threads collaborate on one 128-element row
constexpr int kThreadsPerBlock_Rowwise = 512; // Thread block size, 8 warps in total constexpr int kThreadsPerBlock_Rowwise_Opt = 512;
constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise; constexpr int kThreadsPerRow_Rowwise = 8; // 8 threads per row, each handles 16 elements = 128 total
constexpr int kSMemCol_Rowwise = (kTileDim / kNVecSMem_Rowwise); constexpr int kRowsPerBlock_Rowwise = kThreadsPerBlock_Rowwise_Opt / kThreadsPerRow_Rowwise; // 64 rows per block
constexpr int kSMemSize_Rowwise = kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;
template <bool kAligned, typename CType, typename IType, typename OType> template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise) __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise_Opt, 4)
block_scaled_1d_cast_transpose_kernel_rowwise(const IType* const input, OType* const output_c, block_scaled_1d_cast_transpose_kernel_rowwise(const IType* const input, OType* const output_c,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_c,
const size_t row_length, const size_t num_rows, const size_t row_length, const size_t num_rows,
...@@ -1008,168 +1007,90 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise) ...@@ -1008,168 +1007,90 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return; return;
} }
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; if (rowwise_option == FP8BlockwiseRowwiseOption::NONE) return;
using SMemVec = Vec<IType, kNVecSMem_Rowwise>; using IVec = Vec<IType, kNVecOut>; // 16 elements per thread
using OVec = Vec<OType, kNVecOut>; using OVec = Vec<OType, kNVecOut>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem_Rowwise> smem_type;
};
extern __shared__ char smem_base[]; // Thread indexing: 8 threads per row, 64 rows per block
SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base); const int thr_in_row = threadIdx.x % kThreadsPerRow_Rowwise; // 0-7: position within row
// Step 1: Load input to shared memory const int row_in_block = threadIdx.x / kThreadsPerRow_Rowwise; // 0-63: which row in block
{
constexpr int r_stride =
kThreadsPerBlock_Rowwise / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim64_Rowwise / r_stride; //64/16=4
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem_Rowwise); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
c_s * kNVecSMem_Rowwise; // Column in global memory
size_t r_g =
static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length
? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0; // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory - row Major
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem_Rowwise; ++i) {
int c = c_s + i;
int r = r_s;
// row Major Store
smem[r * kSMemCol_Rowwise + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads(); // Global position
const size_t r_g = static_cast<size_t>(blockIdx.y) * kRowsPerBlock_Rowwise + row_in_block;
const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + thr_in_row * kNVecOut;
// Step 2: Cast and store to output_c // Early exit if out of bounds
if (return_rowwise) { if constexpr (!kAligned) {
constexpr int r_stride = if (r_g >= num_rows) return;
kThreadsPerBlock_Rowwise / kNumThreadsStore; // stride in rows of shared memory }
constexpr int num_iterations = kTileDim64_Rowwise / r_stride;
const int c_s = (threadIdx.x % kNumThreadsStore) *
(kNVecOut / kNVecSMem_Rowwise); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
c_s * kNVecSMem_Rowwise; // Column in global memory
size_t r_g =
static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length
? std::min(static_cast<size_t>(kNVecOut), row_length - c_g)
: 0; // For not aligned case
OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory
// Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
// the first thread to do the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll // Calculate number of elements for non-aligned case
for (int iter = 0; iter < num_iterations; ++iter) { const size_t num_ele = c_g < row_length
SMemVec smem_vec[kNVecOut / kNVecSMem_Rowwise]; ? (c_g + kNVecOut <= row_length ? kNVecOut : row_length - c_g)
// Step 2.1: Load from shared memory to registers - Column Major : 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) { // Step 1: Load directly from global memory to registers (NO shared memory!)
int c = c_s + i; IVec input_vec;
int r = r_s; const IType* input_g = &input[r_g * row_length + c_g];
// Column Major Read if constexpr (kAligned) {
smem_vec[i] = smem[r * kSMemCol_Rowwise + c]; input_vec.load_from(input_g);
} } else {
if (num_ele > 0) {
input_vec.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.clear();
}
}
// Step 2.2: Compute local amax // Step 2: Compute local amax (16 elements per thread)
CType amax = 0; CType amax = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) { for (int i = 0; i < kNVecOut; ++i) {
#pragma unroll __builtin_assume(amax >= 0);
for (int j = 0; j < kNVecSMem_Rowwise; ++j) { amax = fmaxf(amax, fabsf(static_cast<CType>(input_vec.data.elt[i])));
__builtin_assume(amax >= 0); }
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax // Step 3: Reduce amax across 8 threads (128 elements total)
#pragma unroll #pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { for (int delta = kThreadsPerRow_Rowwise / 2; delta > 0; delta /= 2) {
//const float other_amax =__shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp); const float other_amax = __shfl_xor(amax, delta, kThreadsPerRow_Rowwise);
const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128); __builtin_assume(amax >= 0);
__builtin_assume(amax >= 0); __builtin_assume(other_amax >= 0);
__builtin_assume(other_amax >= 0); amax = fmaxf(amax, other_amax);
}
amax = fmaxf(amax, other_amax); // Step 4: Compute scale
} CType scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
CType scale; // Step 5: Write scale_inv (only first thread in each row)
// Step 2.4: Compute scale if (thr_in_row == 0) {
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling); CType scale_inv = 1.0f / scale;
// Step 2.5: Write scale_inv size_t row_idx = r_g;
bool write_scale_inv = is_src_lane; size_t col_idx = static_cast<size_t>(blockIdx.x);
if constexpr (!kAligned) { tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
write_scale_inv &= (r_g < num_rows); }
}
if (write_scale_inv) { // Step 6: Quantize directly in registers
CType scale_inv = 1.0 / scale; OVec output_vec;
size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s; //
size_t col_idx = static_cast<size_t>(blockIdx.x);
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
#pragma unroll #pragma unroll
for (int j = 0; j < kNVecSMem_Rowwise; ++j) { for (int i = 0; i < kNVecOut; ++i) {
if constexpr (std::is_same_v<OType, int8_t>) { if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem_Rowwise + j] = static_cast<OType>(lroundf(fmaxf( output_vec.data.elt[i] = static_cast<OType>(lroundf(
-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale)))); fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(input_vec.data.elt[i]) * scale))));
} else { } else {
output_vec.data.elt[i * kNVecSMem_Rowwise + j] = output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale); static_cast<OType>(static_cast<CType>(input_vec.data.elt[i]) * scale);
} }
} }
}
// Step 2.7: Store output_c // Step 7: Store directly to global memory
if constexpr (kAligned) { OType* output_g = &output_c[r_g * row_length + c_g];
output_vec.store_to(output_g); if constexpr (kAligned) {
} else { output_vec.store_to(output_g);
if (r_g < num_rows) { } else {
output_vec.store_to_elts(output_g, 0, num_ele); if (num_ele > 0) {
} output_vec.store_to_elts(output_g, 0, num_ele);
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
} }
} }
} }
...@@ -1177,7 +1098,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise) ...@@ -1177,7 +1098,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
constexpr int kTileDim64_Colwise = 64; constexpr int kTileDim64_Colwise = 64;
constexpr int kNVecSMem_Colwise = 2; constexpr int kNVecSMem_Colwise = 2;
constexpr int kSMemRow_Colwise = kTileDim; constexpr int kSMemRow_Colwise = kTileDim;
constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise); constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise) + 1; // Padding to avoid bank conflict
constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise; constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise;
constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn; constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn;
constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut; constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut;
...@@ -1583,19 +1504,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -1583,19 +1504,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
while (true) { while (true) {
if (128 == block_len) { if (128 == block_len) {
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len)); const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2)); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kRowsPerBlock_Rowwise);
if (smem_bytes >= 48 * 1024) { dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaError_t err = cudaFuncSetAttribute(
(const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<
kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType, block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType,
OutputType> OutputType>
<<<grid, kThreadsPerBlock_Rowwise, smem_bytes, stream>>>( <<<grid, kThreadsPerBlock_Rowwise_Opt, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr), reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows, reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment