Unverified Commit c3fab5f7 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Bugfix][Kernel] Prevent integer overflow in fp8 dynamic per-token quantize kernel (#9425)

parent 776dbd74
...@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( ...@@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x; int const tid = threadIdx.x;
int const token_idx = blockIdx.x; int const token_idx = blockIdx.x;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; // Use int64 to avoid overflowing an int32 when calculating this offset
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size]; int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 8-byte and 4-byte addresses respectively.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment