Commit 261e476b authored by zc20020701, committed by wenjh

[DCU] Remove redundant shared memory in rowwise kernel


Signed-off-by: zhaochao <zhaochao1@sugon.com>

See merge request dcutoolkit/deeplearing/TransformerEngine!72
Co-authored-by: zhaochao <zhaochao1@sugon.com>
parent 6c9dc19d
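
Note: the diff below replaces the LDS-staged rowwise path with a register-only one; each group of 8 threads keeps its 128-element row block entirely in registers and reduces the per-row amax with a shuffle. For orientation, here is a minimal standalone sketch of that pattern. It is not this MR's code: the names (kThreadsPerRowSk, kElemsPerThreadSk), the float-only signature, and the scale formula are illustrative assumptions.

```cpp
#include <hip/hip_runtime.h>

// Illustrative constants (assumptions, not this MR's identifiers).
constexpr int kThreadsPerRowSk = 8;    // threads cooperating on one 128-element block
constexpr int kElemsPerThreadSk = 16;  // 8 threads * 16 elements = 128-element block

// Launch sketch: blockDim.x = 512, gridDim.x = ceil(row_length / 128),
// gridDim.y = ceil(num_rows / 64); aligned shapes assumed for brevity.
__global__ void rowwise_scale_sketch(const float* __restrict__ input,
                                     float* __restrict__ scales_inv,
                                     size_t row_length, size_t num_rows) {
  const int thr_in_row = threadIdx.x % kThreadsPerRowSk;    // 0-7: position within row
  const int row_in_block = threadIdx.x / kThreadsPerRowSk;  // 0-63: which row in block
  const size_t row =
      static_cast<size_t>(blockIdx.y) * (blockDim.x / kThreadsPerRowSk) + row_in_block;
  const size_t col = static_cast<size_t>(blockIdx.x) * (kThreadsPerRowSk * kElemsPerThreadSk) +
                     static_cast<size_t>(thr_in_row) * kElemsPerThreadSk;
  if (row >= num_rows || col >= row_length) return;

  // Step 1: load 16 contiguous elements straight into registers (no LDS staging).
  float vals[kElemsPerThreadSk];
  for (int i = 0; i < kElemsPerThreadSk; ++i) vals[i] = input[row * row_length + col + i];

  // Step 2: local amax over this thread's 16 elements.
  float amax = 0.0f;
  for (int i = 0; i < kElemsPerThreadSk; ++i) amax = fmaxf(amax, fabsf(vals[i]));

  // Step 3: butterfly reduction across the 8-thread group; a width-8 __shfl_xor
  // keeps the exchange inside each 8-lane subgroup of the wave.
  for (int delta = kThreadsPerRowSk / 2; delta > 0; delta /= 2)
    amax = fmaxf(amax, __shfl_xor(amax, delta, kThreadsPerRowSk));

  // Steps 4-5: derive a scale (placeholder formula; 448 = FP8 E4M3 max) and let
  // the first thread of each row group publish the reciprocal. Quantization of
  // the 16 register values would follow here.
  const float scale = (amax > 0.0f) ? 448.0f / amax : 1.0f;
  if (thr_in_row == 0) scales_inv[row * ((row_length + 127) / 128) + blockIdx.x] = 1.0f / scale;
}
```
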
@@ -989,15 +989,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp_blocklen_128 = 64;
-constexpr int kTileDim64_Rowwise = 64;
-constexpr int kNVecSMem_Rowwise = 4;  // The number of elements each LDS/STS touches
-constexpr int kThreadsPerBlock_Rowwise = 512;  // Thread block size, 8 warps in total
-constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise;
-constexpr int kSMemCol_Rowwise = (kTileDim / kNVecSMem_Rowwise);
-constexpr int kSMemSize_Rowwise = kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;
+// Optimized rowwise kernel: direct register processing without shared memory.
+// Each warp (64 threads) processes multiple rows; 8 threads collaborate on one 128-element row.
+constexpr int kThreadsPerBlock_Rowwise_Opt = 512;
+constexpr int kThreadsPerRow_Rowwise = 8;  // 8 threads per row, each handles 16 elements = 128 total
+constexpr int kRowsPerBlock_Rowwise = kThreadsPerBlock_Rowwise_Opt / kThreadsPerRow_Rowwise;  // 64 rows per block
template <bool kAligned, typename CType, typename IType, typename OType>
-__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
+__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise_Opt, 4)
block_scaled_1d_cast_transpose_kernel_rowwise(const IType* const input, OType* const output_c,
                                              CType* const tile_scales_inv_c,
                                              const size_t row_length, const size_t num_rows,
@@ -1008,176 +1007,98 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
-  bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
+  if (rowwise_option == FP8BlockwiseRowwiseOption::NONE) return;
-  using SMemVec = Vec<IType, kNVecSMem_Rowwise>;
+  using IVec = Vec<IType, kNVecOut>;  // 16 elements per thread
  using OVec = Vec<OType, kNVecOut>;
-  union IVec {
-    Vec<IType, kNVecIn> input_type;
-    Vec<SMemVec, kNVecIn / kNVecSMem_Rowwise> smem_type;
-  };
-  extern __shared__ char smem_base[];
-  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
-  // Step 1: Load input to shared memory
-  {
-    constexpr int r_stride =
-        kThreadsPerBlock_Rowwise / kNumThreadsLoad;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;  // 64/16 = 4
-    const int c_s =
-        (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem_Rowwise);  // Column in shared memory
-    int r_s = threadIdx.x / kNumThreadsLoad;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
-                       c_s * kNVecSMem_Rowwise;  // Column in global memory
-    size_t r_g =
-        static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
-    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
+  // Thread indexing: 8 threads per row, 64 rows per block
+  const int thr_in_row = threadIdx.x % kThreadsPerRow_Rowwise;    // 0-7: position within row
+  const int row_in_block = threadIdx.x / kThreadsPerRow_Rowwise;  // 0-63: which row in block
+  // Global position
+  const size_t r_g = static_cast<size_t>(blockIdx.y) * kRowsPerBlock_Rowwise + row_in_block;
+  const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + thr_in_row * kNVecOut;
+  // Early exit if out of bounds
+  if constexpr (!kAligned) {
+    if (r_g >= num_rows) return;
+  }
+  // Calculate number of elements for the non-aligned case
  const size_t num_ele = c_g < row_length
-                             ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
-                             : 0;  // For not aligned case
-    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
-#pragma unroll
-    for (int iter = 0; iter < num_iterations; ++iter) {
+                             ? (c_g + kNVecOut <= row_length ? kNVecOut : row_length - c_g)
+                             : 0;
+  // Step 1: Load directly from global memory to registers (no shared memory)
  IVec input_vec;
-      // Step 1.1: Load from global memory (input) to registers
+  const IType* input_g = &input[r_g * row_length + c_g];
  if constexpr (kAligned) {
-      input_vec.input_type.load_from(input_g);
+    input_vec.load_from(input_g);
  } else {
-      if (r_g < num_rows) {
-        input_vec.input_type.load_from_elts(input_g, 0, num_ele);
+    if (num_ele > 0) {
+      input_vec.load_from_elts(input_g, 0, num_ele);
    } else {
-        input_vec.input_type.clear();
+      input_vec.clear();
    }
  }
-      // Step 1.2: Write to shared memory - row major
-#pragma unroll
-      for (int i = 0; i < kNVecIn / kNVecSMem_Rowwise; ++i) {
-        int c = c_s + i;
-        int r = r_s;
-        // Row major store
-        smem[r * kSMemCol_Rowwise + c] = input_vec.smem_type.data.elt[i];
-      }
-      // Step 1.3: Update input address and row index of shared memory (and row index of global memory for the not-aligned case)
-      input_g += stride_g;
-      r_s += r_stride;
-      if constexpr (!kAligned) {
-        r_g += r_stride;
-      }
-    }
-  }
-  __syncthreads();
-  // Step 2: Cast and store to output_c
-  if (return_rowwise) {
-    constexpr int r_stride =
-        kThreadsPerBlock_Rowwise / kNumThreadsStore;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;
-    const int c_s = (threadIdx.x % kNumThreadsStore) *
-                    (kNVecOut / kNVecSMem_Rowwise);  // Column in shared memory
-    int r_s = threadIdx.x / kNumThreadsStore;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim +
-                       c_s * kNVecSMem_Rowwise;  // Column in global memory
-    size_t r_g =
-        static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
-    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length
-                               ? std::min(static_cast<size_t>(kNVecOut), row_length - c_g)
-                               : 0;  // For not aligned case
-    OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
-    // Each group of kNumThreadsStore threads in a warp processes one row; we need the lane id
-    // of the first thread to do the reduction.
-    const unsigned src_lane =
-        (threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore * kNumThreadsStore;
-    // This mask represents which threads should do the reduction together.
-    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
-    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
-#pragma unroll
-    for (int iter = 0; iter < num_iterations; ++iter) {
-      SMemVec smem_vec[kNVecOut / kNVecSMem_Rowwise];
-      // Step 2.1: Load from shared memory to registers - column major
-#pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-        int c = c_s + i;
-        int r = r_s;
-        // Column major read
-        smem_vec[i] = smem[r * kSMemCol_Rowwise + c];
-      }
-      // Step 2.2: Compute local amax
+  // Step 2: Compute local amax (16 elements per thread)
  CType amax = 0;
#pragma unroll
-  for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-#pragma unroll
-    for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
+  for (int i = 0; i < kNVecOut; ++i) {
    __builtin_assume(amax >= 0);
-      amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
-    }
+    amax = fmaxf(amax, fabsf(static_cast<CType>(input_vec.data.elt[i])));
  }
-  // Step 2.3: Reduce amax
+  // Step 3: Reduce amax across 8 threads (128 elements total)
#pragma unroll
-  for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-    // const float other_amax = __shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-    const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
+  for (int delta = kThreadsPerRow_Rowwise / 2; delta > 0; delta /= 2) {
+    const float other_amax = __shfl_xor(amax, delta, kThreadsPerRow_Rowwise);
    __builtin_assume(amax >= 0);
    __builtin_assume(other_amax >= 0);
    amax = fmaxf(amax, other_amax);
  }
-  CType scale;
-  // Step 2.4: Compute scale
-  scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
-  // Step 2.5: Write scale_inv
-  bool write_scale_inv = is_src_lane;
-  if constexpr (!kAligned) {
-    write_scale_inv &= (r_g < num_rows);
-  }
-  if (write_scale_inv) {
-    CType scale_inv = 1.0 / scale;
-    size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;
+  // Step 4: Compute scale
+  CType scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+  // Step 5: Write scale_inv (only the first thread in each row)
+  if (thr_in_row == 0) {
+    CType scale_inv = 1.0f / scale;
+    size_t row_idx = r_g;
    size_t col_idx = static_cast<size_t>(blockIdx.x);
    tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
  }
-  // Step 2.6: Quantize
+  // Step 6: Quantize directly in registers
  OVec output_vec;
#pragma unroll
-  for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-#pragma unroll
-    for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
+  for (int i = 0; i < kNVecOut; ++i) {
    if constexpr (std::is_same_v<OType, int8_t>) {
-      output_vec.data.elt[i * kNVecSMem_Rowwise + j] = static_cast<OType>(lroundf(fmaxf(
-          -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
+      output_vec.data.elt[i] = static_cast<OType>(lroundf(
+          fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(input_vec.data.elt[i]) * scale))));
    } else {
-      output_vec.data.elt[i * kNVecSMem_Rowwise + j] =
-          static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
+      output_vec.data.elt[i] =
+          static_cast<OType>(static_cast<CType>(input_vec.data.elt[i]) * scale);
    }
  }
-  }
-  // Step 2.7: Store output_c
+  // Step 7: Store directly to global memory
+  OType* output_g = &output_c[r_g * row_length + c_g];
  if constexpr (kAligned) {
    output_vec.store_to(output_g);
  } else {
-    if (r_g < num_rows) {
+    if (num_ele > 0) {
      output_vec.store_to_elts(output_g, 0, num_ele);
    }
  }
-      // Step 2.8: Update output address and row index of shared memory (and row index of global memory for the not-aligned case)
-      output_g += stride_g;
-      r_s += r_stride;
-      if constexpr (!kAligned) {
-        r_g += r_stride;
-      }
-    }
-  }
}
constexpr int kTileDim64_Colwise = 64;
constexpr int kNVecSMem_Colwise = 2;
constexpr int kSMemRow_Colwise = kTileDim;
-constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise);
+constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise) + 1;  // Padding to avoid bank conflicts
constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise;
constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn;
constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut;
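
Aside: a self-contained sketch (assumed names, one 64x64 tile; not this MR's code) of why the one-element pitch pad added to kSMemCol_Colwise above helps. With a power-of-two pitch, threads reading one LDS column all hit the same bank; widening the pitch by one element staggers consecutive rows across banks.

```cpp
#include <hip/hip_runtime.h>

constexpr int kDimSk = 64;  // one 64x64 tile; name is illustrative

// Transpose one tile through LDS. The "+1" pitch means consecutive rows start
// in different banks, so the column-major reads below are conflict-free.
__global__ void transpose_tile_padded(const float* __restrict__ in, float* __restrict__ out) {
  __shared__ float tile[kDimSk][kDimSk + 1];  // padded pitch, as in the constant above
  const int x = threadIdx.x;  // 0..63
  for (int y = threadIdx.y; y < kDimSk; y += blockDim.y) {
    tile[y][x] = in[y * kDimSk + x];  // row-major store
  }
  __syncthreads();
  for (int y = threadIdx.y; y < kDimSk; y += blockDim.y) {
    out[y * kDimSk + x] = tile[x][y];  // column-major read
  }
}
// Example launch for a single tile: transpose_tile_padded<<<1, dim3(64, 16)>>>(d_in, d_out);
```
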
@@ -1583,19 +1504,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
while (true) {
if (128 == block_len) {
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
-        size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
-        const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2));
-        if (smem_bytes >= 48 * 1024) {
-          cudaError_t err = cudaFuncSetAttribute(
-              (const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<
-                  kAligned, float, InputType, OutputType>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-          NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
-        }
+        const size_t num_blocks_y = DIVUP(num_rows, (size_t)kRowsPerBlock_Rowwise);
dim3 grid(num_blocks_x, num_blocks_y, 1);
block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType,
OutputType>
-            <<<grid, kThreadsPerBlock_Rowwise, smem_bytes, stream>>>(
+            <<<grid, kThreadsPerBlock_Rowwise_Opt, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows,
......
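
For orientation, a hedged host-side sketch of the launch geometry after this change, with the constants inlined as assumptions (kTileDim = 128, kRowsPerBlock_Rowwise = 64, 512 threads per block):

```cpp
#include <hip/hip_runtime.h>

#define DIVUP(a, b) (((a) + (b)-1) / (b))  // rounding-up integer division, as in the source

// Assumed values inlined: 128 columns (kTileDim) and 64 rows (kRowsPerBlock_Rowwise) per block.
dim3 rowwise_grid(size_t row_length, size_t num_rows) {
  const size_t num_blocks_x = DIVUP(row_length, (size_t)128);  // one block per 128 columns
  const size_t num_blocks_y = DIVUP(num_rows, (size_t)64);     // one block per 64 rows
  return dim3(num_blocks_x, num_blocks_y, 1);
}
// The kernel no longer requests dynamic LDS (third launch argument is 0), which
// is why the 48 KiB cudaFuncSetAttribute opt-in path above became dead code:
// kernel<<<rowwise_grid(row_length, num_rows), 512, 0, stream>>>(...);
```
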