Unverified Commit a99c056b authored by Oleg Goncharov, committed by GitHub

[Common] Fixed integer overflow issue in cast kernels (#1988)



* Fixed integer overflow when computing offsets
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 71b2dd48
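For context, a minimal standalone sketch of the failure mode this patch addresses (not taken from the repository; the dimensions below are hypothetical). Flattened offsets such as `i * cols + j` are computed from `size_t` loop indices and dimensions, but were previously stored in 32-bit `int` variables, which silently wrap once the flattened index exceeds INT_MAX (roughly 2.1e9 elements):

```cpp
// Sketch only: illustrates the int -> size_t offset fix, not code from this commit.
#include <cstdio>
#include <cstddef>

int main() {
  const size_t rows = 70000;   // hypothetical large tensor
  const size_t cols = 70000;   // rows * cols ~= 4.9e9 elements, > INT_MAX
  const size_t i = rows - 1;
  const size_t j = cols - 1;

  const int idx_narrowed = i * cols + j;  // truncated on conversion: wrong offset
  const size_t idx = i * cols + j;        // 64-bit offset: correct

  std::printf("int offset:    %d\n", idx_narrowed);
  std::printf("size_t offset: %zu\n", idx);
  return 0;
}
```

The diff below applies the same change (`int` to `size_t`) to every offset, stride, and scale-index computation in the cast kernels and their reference tests.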
@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = i * scales_stride_rowwise + tile_X;
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = tile_Y * scales_stride_colwise + j;
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const int idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
@@ -64,7 +64,7 @@ void compute_ref(const IType* grad,
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) {
const float x = silu_elt;
@@ -101,7 +101,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
@@ -109,18 +109,18 @@ void compute_ref(const IType* grad,
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X;
const size_t scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
@@ -139,7 +139,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
@@ -147,18 +147,18 @@ void compute_ref(const IType* grad,
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = tile_Y * scales_stride_colwise + j;
const size_t scale_idx_act = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate;
if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const int scale_idx_gate = scale_idx_act + cols;
const size_t scale_idx_gate = scale_idx_act + cols;
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM;
const int next_it = it + 1;
const size_t buff = it % BUFFERS_NUM;
const size_t next_it = it + 1;
if (next_it < ITERATIONS) {
const int next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t next_buff = next_it % BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);
// # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t thread_offset_Y_colwise = tid_Y_colwise;
const size_t thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const int gate_scale_idx_offset_colwise = cols;
const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const size_t gate_scale_idx_offset_colwise = cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
__shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];
extern __shared__ char dynamic_shmem[];
@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations
IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mbarrier_wait_parity(&mbar[stage], parity);
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise =
const size_t shmem_offset_base_colwise =
buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_colwise =
const size_t shmem_offset_colwise =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const int scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_elt =
const size_t shmem_offset_elt =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
if constexpr (IS_DGATED) {
OType2 out_pair;
@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f;
@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act;
@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
@@ -80,33 +80,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X;
const int tid_X_rowwise = threadIdx.x % THREADS_X;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
const size_t tid_Y_colwise = 0;
const size_t tid_X_colwise = threadIdx.x;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t thread_offset_Y_colwise = tid_Y_colwise;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
@@ -139,7 +139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
@@ -173,20 +173,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM;
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
@@ -205,7 +205,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
@@ -215,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
@@ -223,7 +223,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else {
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) {
@@ -263,9 +263,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -282,13 +283,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
@@ -301,9 +303,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
@@ -319,9 +321,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
@@ -354,9 +356,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
@@ -406,9 +408,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -434,9 +436,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
}
}
@@ -452,9 +454,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
@@ -485,18 +487,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
const size_t shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
const size_t shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
@@ -504,15 +506,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const size_t dbias_stride = cols;
const size_t dbias_offset_Y = blockIdx.y;
const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
@@ -561,19 +563,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
const size_t thread_offset_Y = tid_Y;
const size_t thread_offset_X = tid_X;
const int dbias_offset_Y = blockIdx.y + tid_Y;
const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const size_t dbias_offset_Y = blockIdx.y + tid_Y;
const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const bool col_out_of_bounds = my_column >= cols;
const int dbias_stride = cols;
const size_t dbias_stride = cols;
float partial_dbias = 0.f;
@@ -588,7 +590,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
__shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
@@ -600,13 +602,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
int parity = 0;
const int chunk_offset_Y = block_offset_Y;
const int chunk_offset_X = block_offset_X;
const size_t chunk_offset_Y = block_offset_Y;
const size_t chunk_offset_X = block_offset_X;
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const int chunk_stage_offset_X = chunk_offset_X;
const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const size_t chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
@@ -621,13 +623,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll
for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
const int buff = iter % FP8_BUFFERS_NUM;
const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t buff = iter % FP8_BUFFERS_NUM;
const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
if (next_iter < FP8_ITERATIONS) {
const int next_buff = next_iter % FP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
@@ -644,9 +646,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll
for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const size_t stage_offset_Y = stage;
const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const size_t shmem_offset_x = thread_offset_X;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
@@ -685,8 +687,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
@@ -704,8 +706,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
parity ^= 1;
if constexpr (IS_DBIAS) {
const int dbias_offset_X = my_column;
const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
const size_t dbias_offset_X = my_column;
const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias;
}
@@ -747,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset;
@@ -758,8 +760,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0);
@@ -775,12 +777,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const int buff = iter % SHMEM_BUFFERS;
const int it_offset = iter * SHMEM_DIM;
const size_t buff = iter % SHMEM_BUFFERS;
const size_t it_offset = iter * SHMEM_DIM;
const int next_iter = iter + 1;
const int next_buff = next_iter % SHMEM_BUFFERS;
const int next_iter_offset = next_iter * SHMEM_DIM;
const size_t next_iter = iter + 1;
const size_t next_buff = next_iter % SHMEM_BUFFERS;
const size_t next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
@@ -794,7 +796,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
@@ -847,12 +849,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
constexpr size_t DBIAS_THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows,
const int cols) {
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
@@ -883,8 +885,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec);
@@ -1244,8 +1246,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16;
const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}