Commit aa43da4b authored by Daniel Hiltgen, committed by Michael Yang

cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4
parent 0ac1c0d3
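
To see why the commit message counts 4 bytes as 8 elements: MXFP4 packs two 4-bit (E2M1) codes per byte, 32 codes per block, next to a single E8M0 scale byte, so one 32-bit read of qs yields 8 elements. Below is a minimal host-side sketch of that unpacking, assuming this block layout; Mxfp4Block and decode8 are illustrative names, not the kernel's identifiers.

    #include <cstdint>
    #include <cstdio>

    // Assumed MXFP4 block layout: one E8M0 scale byte plus 16 bytes holding
    // 32 four-bit codes, low nibble first (matching the kernel's el[] order).
    struct Mxfp4Block {
        uint8_t d;      // E8M0 scale (a bare 8-bit exponent)
        uint8_t qs[16]; // 32 x 4-bit codes, two per byte
    };

    // Unpack the 8 codes held in the 4 bytes starting at byte offset xi.
    static void decode8(const Mxfp4Block &b, int xi, uint8_t el[8]) {
        for (int j = 0; j < 4; j++) {
            const uint8_t q = b.qs[xi + j];
            el[2*j]     = q & 0x0f;        // low nibble: even element
            el[2*j + 1] = (q & 0xf0) >> 4; // high nibble: odd element
        }
    }

    int main() {
        Mxfp4Block b = {127, {0x21}};   // codes 1 and 2 in the first byte
        uint8_t el[8];
        decode8(b, 0, el);
        printf("%u %u\n", el[0], el[1]); // prints: 1 2
        return 0;
    }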
@@ -10,8 +10,8 @@ typedef union {
 template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
 static __global__ void mul_mat_vec_mxfp4(
-        const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
+        const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
 
     const int64_t row = blockIdx.x;
@@ -23,16 +23,20 @@ static __global__ void mul_mat_vec_mxfp4(
     const int64_t sample_y = sample_dst;
     const int tid = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int64_t ncols8 = ncols / 8;
 
     const uint16_t dst_bias   = 15;
     const uint16_t dst_0p5    = 0x3800;
     const uint16_t dst_m_bits = 10;
 
-    x   += sample_x  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
-    y   += sample_y  *stride_sample_y   + channel_y  *stride_channel_y;
+    // x_base is offset by blocks of 32 elements
+    x_base += sample_x  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    // y is offset by elements
+    y      += sample_y  *stride_sample_y   + channel_y  *stride_channel_y;
+    // dst is offset by elements
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
 
-    const float2 * y2 = (const float2 *) y;
+    const float4 * y4 = (const float4 *) y;
 
     extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
     float * buf_iw = (float *) data_mmv;
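
The float2 to float4 recast of y is what widens the memory traffic: a float4 load from a 16-byte-aligned pointer compiles to a single 128-bit transaction, so each thread now fetches 4 consecutive y values per instruction instead of 2. A hedged CUDA illustration of the pattern; copy4 is a made-up example kernel, not part of this commit.

    // Each thread moves 4 floats with one 128-bit load and one 128-bit store.
    // Requires src and dst to be 16-byte aligned (cudaMalloc guarantees this).
    __global__ void copy4(const float * __restrict__ src, float * __restrict__ dst, int n4) {
        const float4 * src4 = (const float4 *) src;
        float4 * dst4 = (float4 *) dst;
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n4) {
            dst4[i] = src4[i];
        }
    }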
@@ -46,50 +50,72 @@ static __global__ void mul_mat_vec_mxfp4(
     float sumf = 0.0f;
 
-    for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-        int offset0 = col2 / (MXFP4/2);
-        int i = col2 % (MXFP4/2);
-        const block_mxfp4 *x2 = x+offset0;
+    // each i8 index processes 8 items at a time
+    for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
+        // As i8 indexes past a block, we have to offset further
+        int offset0 = i8 / (MXFP4/8);
+        int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
+        const block_mxfp4 *x = x_base+offset0;
 
         union {
             uint32_t as_bits;
             float    as_value;
         } scale;
-        scale.as_bits = (((uint32_t)x2->d) << 23);
+        scale.as_bits = (((uint32_t)x->d) << 23);
+        if (isnan(scale.as_value)) {
+            sumf = scale.as_value;
+            break;
+        }
 
-        uint16_t em0 = x2->qs[i] & 0x07;
-        uint16_t em1 = x2->qs[i] & 0x70;
-        // float16 values
-        f16_t x0;
-        f16_t x1;
-        x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
-        x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
+        const uint8_t qs[4] = {
+            (uint8_t)(x->qs[xi]),
+            (uint8_t)(x->qs[xi+1]),
+            (uint8_t)(x->qs[xi+2]),
+            (uint8_t)(x->qs[xi+3])
+        };
+        const uint8_t el[8] = {
+            (uint8_t)(qs[0] & 0xf),
+            (uint8_t)((qs[0] & 0xf0) >> 4),
+            (uint8_t)(qs[1] & 0xf),
+            (uint8_t)((qs[1] & 0xf0) >> 4),
+            (uint8_t)(qs[2] & 0xf),
+            (uint8_t)((qs[2] & 0xf0) >> 4),
+            (uint8_t)(qs[3] & 0xf),
+            (uint8_t)((qs[3] & 0xf0) >> 4)
+        };
+        uint16_t em[8];
+#pragma unroll
+        for (int i = 0; i < 8; i++) { em[i] = (uint16_t)(el[i] & 0x07); }
+        // float16 values
+        f16_t x4u[8];
+#pragma unroll
+        for (int i = 0; i < 8; i++) { x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12); }
 
         // Three cases:
         // x is normal and non-zero: Correct bias
-        if ((em0 & 0x06) != 0) {
-            x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
-        }
-        if ((em1 & 0x60) != 0) {
-            x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
-        }
+#pragma unroll
+        for (int i = 0; i < 8; i++) { if ((em[i] & 0x06) != 0) { x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits); } }
 
         // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
-        if (em0 == 0x01) {
-            x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
-        }
-        if (em1 == 0x10) {
-            x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
-        }
+#pragma unroll
+        for (int i = 0; i < 8; i++) { if (em[i] == 0x01) { x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000); } }
 
         // x is zero, do nothing
-        if (isnan(scale.as_value)) {
-            sumf = scale.as_value;
-            break;
-        }
 
-        const float2 tmpx = {x0.f16, x1.f16};
-        const float2 tmpy = y2[col2];
-        sumf += tmpx.x*tmpy.x*scale.as_value;
-        sumf += tmpx.y*tmpy.y*scale.as_value;
+        const float scalef = scale.as_value;
+        const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
+        const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
+        const float4 tmpy0 = y4[i8*2];
+        const float4 tmpy1 = y4[i8*2+1];
+        sumf += tmpx0.x * tmpy0.x * scalef;
+        sumf += tmpx0.y * tmpy0.y * scalef;
+        sumf += tmpx0.z * tmpy0.z * scalef;
+        sumf += tmpx0.w * tmpy0.w * scalef;
+        sumf += tmpx1.x * tmpy1.x * scalef;
+        sumf += tmpx1.y * tmpy1.y * scalef;
+        sumf += tmpx1.z * tmpy1.z * scalef;
+        sumf += tmpx1.w * tmpy1.w * scalef;
     }
 
     sumf = warp_reduce_sum<warp_size>(sumf);
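
The bit surgery in the loop above rebiases each 4-bit E2M1 code into a float16. For intuition, the same mapping can be written as a lookup: magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} for codes 0..7, with bit 3 as the sign, which matches the kernel's subnormal case (code 1 maps to +-0.5). A host-side reference sketch for checking the kernel's three cases, not the kernel's own code:

    #include <cstdint>
    #include <cstdio>

    // Reference FP4 (E2M1) decode: low 3 bits select the magnitude, bit 3 the sign.
    static float fp4_e2m1_to_float(uint8_t code) {
        static const float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
        const float mag = lut[code & 0x07];
        return (code & 0x08) ? -mag : mag;
    }

    int main() {
        for (unsigned c = 0; c < 16; c++) {
            printf("code %2u -> %g\n", c, fp4_e2m1_to_float((uint8_t)c));
        }
        return 0;
    }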
@@ -151,42 +177,42 @@ static void launch_mul_mat_vec_cuda_mxfp4(
     switch (block_size_best) {
         case 32: {
             mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 64: {
             mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 96: {
             mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
             mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
             mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
             mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
             mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
             mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
...
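
For reference, the kernel's one-line scale decode (scale.as_bits = ((uint32_t)x->d) << 23) works because E8M0 is a bare biased exponent: shifting it into the float32 exponent field yields 2^(d - 127) with no multiply. A small host-side check of that identity, using memcpy in place of the kernel's union:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Decode an E8M0 scale byte: sign 0, mantissa 0, exponent d -> 2^(d - 127).
    static float e8m0_to_float(uint8_t d) {
        const uint32_t bits = (uint32_t)d << 23;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    int main() {
        printf("%g %g %g\n", e8m0_to_float(127), e8m0_to_float(126), e8m0_to_float(130));
        // prints: 1 0.5 8  (2^0, 2^-1, 2^3)
        return 0;
    }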