Unverified Commit a99c056b authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Fixed integer overflow issue in cast kernels (#1988)



* Fixed integer overflow when computing offsets
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 71b2dd48
......@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
......@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = i * scales_stride_rowwise + tile_X;
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
......@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = tile_Y * scales_stride_colwise + j;
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
......
......@@ -64,7 +64,7 @@ void compute_ref(const IType* grad,
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) {
const float x = silu_elt;
......@@ -101,7 +101,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
......@@ -109,18 +109,18 @@ void compute_ref(const IType* grad,
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X;
const size_t scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
......@@ -139,7 +139,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
......@@ -147,18 +147,18 @@ void compute_ref(const IType* grad,
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = tile_Y * scales_stride_colwise + j;
const size_t scale_idx_act = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const int scale_idx_gate = scale_idx_act + cols;
const size_t scale_idx_gate = scale_idx_act + cols;
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
......
......@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
......@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM;
const int next_it = it + 1;
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const int next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
......@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
......@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
......@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);
// # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t thread_offset_Y_colwise = tid_Y_colwise;
const size_t thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const int gate_scale_idx_offset_colwise = cols;
const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const size_t gate_scale_idx_offset_colwise = cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
__shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];
extern __shared__ char dynamic_shmem[];
......@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations
IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
......@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
......@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mbarrier_wait_parity(&mbar[stage], parity);
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise =
const size_t shmem_offset_base_colwise =
buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
......@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_colwise =
const size_t shmem_offset_colwise =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
......@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
......@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const int scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
......@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_elt =
const size_t shmem_offset_elt =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
if constexpr (IS_DGATED) {
OType2 out_pair;
......@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
......@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
......@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act;
......@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
......@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
......@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
......@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment