Commit aa43da4b authored by Daniel Hiltgen, committed by Michael Yang

cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4
parent 0ac1c0d3
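
For context on the optimization: instead of decoding two 4-bit elements per loop iteration, the kernel now reads 4 packed bytes and decodes 8 elements at once, all sharing one E8M0 block scale. Below is a minimal host-side sketch of that decode step; it is illustrative only (the function name and lookup table are not from this commit) and assumes the FP4 E2M1 value set and low-nibble-first packing used by the kernel.

#include <stdint.h>

// FP4 E2M1 values, indexed by the 4-bit code (bit 3 is the sign).
static const float kFp4Lut[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

// qs points at 4 packed bytes (8 nibbles); scale_bits is the block's E8M0 scale.
static void decode8_mxfp4(const uint8_t *qs, uint8_t scale_bits, float out[8]) {
    union { uint32_t u; float f; } scale;
    scale.u = (uint32_t)scale_bits << 23;   // E8M0 -> float32, same trick as the kernel
    for (int b = 0; b < 4; b++) {
        out[2*b + 0] = kFp4Lut[qs[b] & 0x0f]        * scale.f;  // low nibble
        out[2*b + 1] = kFp4Lut[(qs[b] >> 4) & 0x0f] * scale.f;  // high nibble
    }
}
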
@@ -10,8 +10,8 @@ typedef union {
template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
static __global__ void mul_mat_vec_mxfp4(
const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
@@ -23,16 +23,20 @@ static __global__ void mul_mat_vec_mxfp4(
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int64_t ncols8 = ncols / 8;
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
// x_base is offset by blocks of 32 elements
x_base += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
// y is offset by elements
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
// dst is offset by elements
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
const float4 * y4 = (const float4 *) y;
extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
float * buf_iw = (float *) data_mmv;
@@ -46,50 +50,72 @@ static __global__ void mul_mat_vec_mxfp4(
float sumf = 0.0f;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
int offset0 = col2 / (MXFP4/2);
int i = col2 % (MXFP4/2);
const block_mxfp4 *x2 = x+offset0;
// each i8 index processes 8 items at a time
for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
// As i8 indexes past a block, we have to offset further
int offset0 = i8 / (MXFP4/8);
int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
const block_mxfp4 *x = x_base+offset0;
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x2->d) << 23);
uint16_t em0 = x2->qs[i] & 0x07;
uint16_t em1 = x2->qs[i] & 0x70;
scale.as_bits = (((uint32_t)x->d) << 23);
if (isnan(scale.as_value)) {
sumf = scale.as_value;
break;
}
const uint8_t qs[4] = {
(uint8_t)(x->qs[xi]),
(uint8_t)(x->qs[xi+1]),
(uint8_t)(x->qs[xi+2]),
(uint8_t)(x->qs[xi+3])
};
const uint8_t el[8] = {
(uint8_t)(qs[0] & 0xf),
(uint8_t)((qs[0] & 0xf0) >> 4),
(uint8_t)(qs[1] & 0xf),
(uint8_t)((qs[1] & 0xf0) >> 4),
(uint8_t)(qs[2] & 0xf),
(uint8_t)((qs[2] & 0xf0) >> 4),
(uint8_t)(qs[3] & 0xf),
(uint8_t)((qs[3] & 0xf0) >> 4)
};
uint16_t em[8];
#pragma unroll
for (int i = 0; i < 8; i++) { em[i] = (uint16_t)(el[i] & 0x07); }
// float16 values
f16_t x0;
f16_t x1;
x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
f16_t x4u[8];
#pragma unroll
for (int i = 0; i < 8; i++) { x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12); }
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
}
#pragma unroll
for (int i = 0; i < 8; i++) { if ((em[i] & 0x06) != 0) { x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits); } }
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
}
if (em1 == 0x10) {
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
}
#pragma unroll
for (int i = 0; i < 8; i++) { if (em[i] == 0x01) { x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000); } }
// x is zero, do nothing
if (isnan(scale.as_value)) {
sumf = scale.as_value;
break;
}
const float2 tmpx = {x0.f16, x1.f16};
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x*scale.as_value;
sumf += tmpx.y*tmpy.y*scale.as_value;
const float scalef = scale.as_value;
const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
const float4 tmpy0 = y4[i8*2];
const float4 tmpy1 = y4[i8*2+1];
sumf += tmpx0.x * tmpy0.x * scalef;
sumf += tmpx0.y * tmpy0.y * scalef;
sumf += tmpx0.z * tmpy0.z * scalef;
sumf += tmpx0.w * tmpy0.w * scalef;
sumf += tmpx1.x * tmpy1.x * scalef;
sumf += tmpx1.y * tmpy1.y * scalef;
sumf += tmpx1.z * tmpy1.z * scalef;
sumf += tmpx1.w * tmpy1.w * scalef;
}
sumf = warp_reduce_sum<warp_size>(sumf);
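
For reference, the three-case FP4 -> FP16 mapping that the unrolled loops above apply to each element can be written as a standalone function. This is an illustrative sketch of the same bit arithmetic (assuming the IEEE half layout of 1 sign, 5 exponent, and 10 mantissa bits), not code from the commit:

#include <stdint.h>

static uint16_t fp4_to_half_bits(uint8_t el) {
    const uint16_t dst_bias   = 15;      // fp16 exponent bias
    const uint16_t dst_0p5    = 0x3800;  // 0.5 encoded as fp16
    const uint16_t dst_m_bits = 10;      // fp16 mantissa width

    uint16_t em = el & 0x07;  // exponent + mantissa bits of the fp4 code
    // place the fp4 fields in the fp16 word and copy the sign to bit 15
    uint16_t h = (uint16_t)((em << (dst_m_bits - 1)) | ((el & 0x08) << 12));
    if ((em & 0x06) != 0) {
        // normal, non-zero: rebias the exponent from fp4 (bias 1) to fp16 (bias 15)
        h = (uint16_t)(h + ((dst_bias - 1) << dst_m_bits));
    } else if (em == 0x01) {
        // fp4 subnormal (+-0.5): substitute fp16 0.5 and keep the sign
        h = (uint16_t)(dst_0p5 | (h & 0x8000));
    }
    // em == 0: the value is +-0 and the initial word is already correct
    return h;
}
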
@@ -151,42 +177,42 @@ static void launch_mul_mat_vec_cuda_mxfp4(
switch (block_size_best) {
case 32: {
mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
......