Commit 4fb47ed3 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by Michael Yang
Browse files

MXFP4 support

This implements the Open Compute Microscaling (MX) FP4 format
as a tensor type with backend implementations focusing
on mulmat and mulmatid on CPU, CUDA, and Metal.
parent 9194874d
...@@ -191,8 +191,8 @@ const ( ...@@ -191,8 +191,8 @@ const (
TensorTypeF16 TensorTypeF16
TensorTypeQ4_0 TensorTypeQ4_0
TensorTypeQ4_1 TensorTypeQ4_1
tensorTypeQ4_2 // unused by GGML TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
tensorTypeQ4_3 // unused by GGML tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0 TensorTypeQ5_0
TensorTypeQ5_1 TensorTypeQ5_1
TensorTypeQ8_0 TensorTypeQ8_0
...@@ -264,6 +264,8 @@ func ParseTensorType(s string) (TensorType, error) { ...@@ -264,6 +264,8 @@ func ParseTensorType(s string) (TensorType, error) {
return TensorTypeF64, nil return TensorTypeF64, nil
case "BF16": case "BF16":
return TensorTypeBF16, nil return TensorTypeBF16, nil
case "MXFP4":
return TensorTypeMXFP4, nil
default: default:
return 0, fmt.Errorf("unsupported quantization type %s", s) return 0, fmt.Errorf("unsupported quantization type %s", s)
} }
...@@ -316,6 +318,8 @@ func (t TensorType) String() string { ...@@ -316,6 +318,8 @@ func (t TensorType) String() string {
return "F64" return "F64"
case TensorTypeBF16: case TensorTypeBF16:
return "BF16" return "BF16"
case TensorTypeMXFP4:
return "MXFP4"
default: default:
return "unknown" return "unknown"
} }
......
This diff is collapsed.
...@@ -469,4 +469,5 @@ const ( ...@@ -469,4 +469,5 @@ const (
DTypeQ80 DTypeQ80
DTypeQ40 DTypeQ40
DTypeI32 DTypeI32
DTypeMXFP4
) )
...@@ -708,6 +708,8 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { ...@@ -708,6 +708,8 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
cdtype = C.GGML_TYPE_Q4_0 cdtype = C.GGML_TYPE_Q4_0
case ml.DTypeI32: case ml.DTypeI32:
cdtype = C.GGML_TYPE_I32 cdtype = C.GGML_TYPE_I32
case ml.DTypeMXFP4:
cdtype = C.GGML_TYPE_MXFP4
default: default:
panic("unsupported dtype") panic("unsupported dtype")
} }
...@@ -896,6 +898,8 @@ func (t *Tensor) DType() ml.DType { ...@@ -896,6 +898,8 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeQ40 return ml.DTypeQ40
case C.GGML_TYPE_I32: case C.GGML_TYPE_I32:
return ml.DTypeI32 return ml.DTypeI32
case C.GGML_TYPE_MXFP4:
return ml.DTypeMXFP4
default: default:
return ml.DTypeOther return ml.DTypeOther
} }
......
...@@ -353,7 +353,7 @@ extern "C" { ...@@ -353,7 +353,7 @@ extern "C" {
GGML_TYPE_F16 = 1, GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3, GGML_TYPE_Q4_1 = 3,
// GGML_TYPE_Q4_2 = 4, support has been removed GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2
// GGML_TYPE_Q4_3 = 5, support has been removed // GGML_TYPE_Q4_3 = 5, support has been removed
GGML_TYPE_Q5_0 = 6, GGML_TYPE_Q5_0 = 6,
GGML_TYPE_Q5_1 = 7, GGML_TYPE_Q5_1 = 7,
......
...@@ -417,6 +417,13 @@ typedef struct { ...@@ -417,6 +417,13 @@ typedef struct {
} block_iq4_xs; } block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
......
...@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const ...@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
...@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { ...@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_MXFP4] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
}; };
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
......
...@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp( ...@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_I32: case GGML_TYPE_I32:
case GGML_TYPE_I64: case GGML_TYPE_I64:
case GGML_TYPE_F64: case GGML_TYPE_F64:
case GGML_TYPE_MXFP4:
case GGML_TYPE_COUNT: case GGML_TYPE_COUNT:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
......
...@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl ...@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
} }
return sum = (ggml_float)logf(sum); return sum = (ggml_float)logf(sum);
} }
#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}
void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
assert(nrc == 1);
GGML_UNUSED(nrc);
GGML_UNUSED(bx);
GGML_UNUSED(by);
GGML_UNUSED(bs);
ggml_float mxfp4_table[] = MXFP4_VALS;
#if defined(GGML_SIMD)
float sumf = 0.0f;
const int np = (n & ~(GGML_F32_STEP - 1));
const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC scalev;
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64
for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4
// convert GGML_F32_ARR X elements
const int ib = (i + j*GGML_F32_EPR) / MXFP4;
const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
scalev = GGML_F32_VEC_SET1(scale.as_value);
float xf[GGML_F32_EPR]= {0.f};
assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun");
for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) {
xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)];
xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4];
}
ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev);
ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
}
}
GGML_F32_VEC_REDUCE(sumf, sum);
// leftovers
for (int i = np; i < n; i+=2) {
const int ib = i / MXFP4;
const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)];
sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4];
}
#else // defined(GGML_SIMD)
const int nb = n / MXFP4;
assert(n % MXFP4 == 0);
int yi = 0;
const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
ggml_float sumf = 0.0;
for (int ib = 0; ib < nb; ++ib) {
const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
for (int i = 0; i < MXFP4/2; ++i) {
sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]);
sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]);
}
}
#endif
*s = sumf;
}
...@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G ...@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_silu_f32(const int n, float * y, const float * x); void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
......
...@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t ...@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
} }
// MXFP4 dequantize derived from dequantize_block_q4_0
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
const int64_t i = blockIdx.x;
// assume 32 threads
const int64_t tid = threadIdx.x;
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
const uint64_t offset = 256*i + MXFP4*ir + 8*il;
dst_t * y = yy + offset;
const block_mxfp4 * x = (const block_mxfp4 *)vx + ib;
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
// offset within the block 1/4 chunks (8 items)
const uint8_t * q = x->qs + 4*il;
for (int l = 0; l < 4; ++l) {
uint16_t em0 = q[l] & 0x07;
uint16_t em1 = q[l] & 0x70;
// float16 values
iq1m_scale_t x0;
iq1m_scale_t x1;
x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12);
x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
}
if (em1 == 0x10) {
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
}
// x is zero, do nothing
// XXX it looks correct here - but mulmat still gives bad results...
// printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
// i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16));
// printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
// i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16));
y[l*2] = scale.as_value * float(x0.f16);
y[l*2+1] = scale.as_value * float(x1.f16);
}
}
// derived from dequantize_row_q4_0_cuda
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template <typename src_t, typename dst_t> template <typename src_t, typename dst_t>
static __global__ void convert_unary( static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
...@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { ...@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return convert_unary_cont_cuda<float>; return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16: case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>; return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
default: default:
return nullptr; return nullptr;
} }
...@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { ...@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return convert_unary_cont_cuda<half>; return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16: case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>; return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
default: default:
return nullptr; return nullptr;
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ggml-cuda/im2col.cuh" #include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh" #include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvmxfp4.cuh"
#include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh" #include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh" #include "ggml-cuda/opt-step-adamw.cuh"
...@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas( ...@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas(
const int cc = ggml_cuda_info().devices[id].cc; const int cc = ggml_cuda_info().devices[id].cc;
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4;
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id)); ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
...@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ...@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1; && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
&& src0->type != GGML_TYPE_MXFP4;
bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
...@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ...@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) { } else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else if (use_mul_mat_vec_mxfp4) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr);
} else { } else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
} }
...@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ...@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) {
ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst);
return;
}
if (ne2 == 1) { if (ne2 == 1) {
if (ggml_is_quantized(src0->type)) { if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
...@@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g ...@@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_BF16: case GGML_TYPE_BF16:
case GGML_TYPE_MXFP4:
#ifdef GGML_USE_MUSA #ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) { if (a->type == GGML_TYPE_Q3_K) {
return false; return false;
......
#include "ggml.h"
#include "common.cuh"
#include "mmvmxfp4.cuh"
// MXFP4 implementation derived from mmv.cu float32 code paths
typedef union {
half f16;
uint16_t u16;
} f16_t;
template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
static __global__ void mul_mat_vec_mxfp4(
const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
const int64_t channel_dst = blockIdx.y;
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int64_t sample_dst = blockIdx.z;
const int64_t sample_x = sample_dst / sample_ratio;
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
float * buf_iw = (float *) data_mmv;
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
}
float sumf = 0.0f;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
int offset0 = col2 / (MXFP4/2);
int i = col2 % (MXFP4/2);
const block_mxfp4 *x2 = x+offset0;
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x2->d) << 23);
uint16_t em0 = x2->qs[i] & 0x07;
uint16_t em1 = x2->qs[i] & 0x70;
// float16 values
f16_t x0;
f16_t x1;
x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
}
if (em1 == 0x10) {
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
}
// x is zero, do nothing
if (isnan(scale.as_value)) {
sumf = scale.as_value;
break;
}
const float2 tmpx = {x0.f16, x1.f16};
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x*scale.as_value;
sumf += tmpx.y*tmpy.y*scale.as_value;
}
sumf = warp_reduce_sum<warp_size>(sumf);
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= warp_size) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum<warp_size>(sumf);
}
if (tid != 0) {
return;
}
dst[row] = sumf;
}
template <typename type_acc>
static void launch_mul_mat_vec_cuda_mxfp4(
const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
// GGML_ASSERT(stride_row % 2 == 0); // TODO
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
block_size_best = block_size;
}
}
const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
static void mul_mat_vec_cuda_mxfp4(
const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
launch_mul_mat_vec_cuda_mxfp4<float>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(ne13 == ne3);
// GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t stride_row = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t stride_channel_x = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t stride_sample_x = src0->nb[3] / ts_src0;
const int64_t stride_sample_y = src1->nb[3] / ts_src1;
const int64_t stride_sample_dst = dst->nb[3] / ts_dst;
const int64_t nsamples_dst = ne3;
const int64_t nsamples_x = ne03;
const int64_t nchannels_x = ne02;
const int64_t nrows = ne01;
const int64_t ncols = ne00;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(ncols_dst == 1);
const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data;
mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream());
}
void ggml_cuda_op_mul_mat_vec_mxfp4(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1_ncols == 1);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00 / MXFP4;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_dst = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i;
mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
GGML_UNUSED(ctx);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}
#include "common.cuh"
void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec_mxfp4(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
...@@ -421,6 +421,13 @@ typedef struct { ...@@ -421,6 +421,13 @@ typedef struct {
} block_iq4_xs; } block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL
...@@ -1929,6 +1936,9 @@ GGML_TABLE_END() ...@@ -1929,6 +1936,9 @@ GGML_TABLE_END()
#define N_R0_IQ4_XS 2 #define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2 #define N_SG_IQ4_XS 2
#define N_R0_MXFP4 4
#define N_SG_MXFP4 2
// kernel argument structs // kernel argument structs
// //
// - element counters (e.g. ne00) typically use int32_t to reduce register usage // - element counters (e.g. ne00) typically use int32_t to reduce register usage
...@@ -4380,16 +4390,16 @@ void mul_vec_q_n_f32_impl( ...@@ -4380,16 +4390,16 @@ void mul_vec_q_n_f32_impl(
device const char * src1, device const char * src1,
device char * dst, device char * dst,
threadgroup char * shmem, threadgroup char * shmem,
uint3 tgpig, uint3 tgpig, // Threadgroup Position in Grid
ushort tiisg, ushort tiisg, // Thread Index in SIMD Group
ushort sgitg) { ushort sgitg) { // SIMD Group Index in ThreadGroup
const int nb = args.ne00/QK4_0; const int nb = args.ne00/QK4_0; // src0->ne[0] / 32
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
const int im = tgpig.z; const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0; const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4
const uint i12 = im%args.ne12; const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12; const uint i13 = im/args.ne12;
...@@ -9222,6 +9232,49 @@ kernel void kernel_mul_mm_id( ...@@ -9222,6 +9232,49 @@ kernel void kernel_mul_mm_id(
} }
} }
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
float4x4 reg_f;
const ushort dst_bias = 15;
const ushort dst_0p5 = 0x3800;
const ushort dst_m_bits = 10;
const half scale = (half)(as_type<float>(((uint32_t)xb->d) << 23));
// il:0 first 16, il:1 last 16
for (int i = 0; i < 8; i++) {
ushort em0 = xb->qs[il*8 + i] & 0x07;
ushort em1 = xb->qs[il*8 + i] & 0x70;
// float16 values
ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing
if (isnan(scale)) {
reg_f[i/2][2*(i%2) + 0] = scale;
reg_f[i/2][2*(i%2) + 1] = scale;
} else {
reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
}
}
reg = (type4x4) reg_f;
}
#define QK_NL 16 #define QK_NL 16
// //
...@@ -9289,6 +9342,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m ...@@ -9289,6 +9342,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
// //
// indirect matrix-matrix multiplication // indirect matrix-matrix multiplication
// //
...@@ -9320,6 +9375,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m ...@@ -9320,6 +9375,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
// //
// matrix-vector multiplication // matrix-vector multiplication
...@@ -9436,6 +9493,120 @@ kernel void kernel_mul_mv_id( ...@@ -9436,6 +9493,120 @@ kernel void kernel_mul_mv_id(
sgitg); sgitg);
} }
// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y
void mul_mv_mxfp4_f32_impl(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const ushort dst_bias = 15;
const ushort dst_0p5 = 0x3800;
const ushort dst_m_bits = 10;
const int nr0 = N_R0_MXFP4;
const int nsg = N_SG_MXFP4;
const int nw = N_SIMDWIDTH;
const int nb = args.ne00/MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_mxfp4 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
}
float yl[16]; // src1 vector cache
float sumf[nr0] = {0.f};
const short ix = (tiisg/2);
const short il = (tiisg%2)*16;
device const float * yb = y + ix*MXFP4 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
#pragma unroll
for (short row = 0; row < nr0; row++) {
// Processes 16 items
device const block_mxfp4 * qb_curr = ax[row] + ib;
float d = as_type<float>(((uint32_t)(ax[row] + ib)->d) << 23);
// il = 0 or 16
device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
for (int i = 0; i < 8; ++i) {
ushort em0 = qs[i] & 0x07;
ushort em1 = qs[i] & 0x70;
ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing
if (!isnan(d)) {
sumf[row] += yb[i*2] * as_type<half>(x0) * d
+ yb[i*2+1] * as_type<half>(x1) * d;
} else {
sumf[row] = d;
}
}
}
yb += MXFP4 * 16;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst_f32[first_row + row] = tot;
}
}
}
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t; typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
...@@ -9465,6 +9636,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t ...@@ -9465,6 +9636,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;
kernel void kernel_pool_2d_max_f32( kernel void kernel_pool_2d_max_f32(
device const float * src0, device const float * src0,
device float * dst, device float * dst,
......
...@@ -65,6 +65,9 @@ ...@@ -65,6 +65,9 @@
#define N_R0_IQ4_XS 2 #define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2 #define N_SG_IQ4_XS 2
#define N_R0_MXFP4 4
#define N_SG_MXFP4 2
// kernel argument structs // kernel argument structs
// //
// - element counters (e.g. ne00) typically use int32_t to reduce register usage // - element counters (e.g. ne00) typically use int32_t to reduce register usage
......
...@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; ...@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
static struct ggml_backend_reg g_ggml_backend_metal_reg; static struct ggml_backend_reg g_ggml_backend_metal_reg;
static struct ggml_backend_device g_ggml_backend_metal_device; static struct ggml_backend_device g_ggml_backend_metal_device;
// information about a Metal device // information about a Metal device
// note: assumes single GPU device - the default one // note: assumes single GPU device - the default one
// TODO: support multiple GPU devices // TODO: support multiple GPU devices
...@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type { ...@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
...@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type { ...@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
...@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type { ...@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
...@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type { ...@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
...@@ -934,7 +939,7 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl ...@@ -934,7 +939,7 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl
MTLCompileOptions * options = [MTLCompileOptions new]; MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep; options.preprocessorMacros = prep;
//[options setFastMathEnabled:false]; //[options setFastMathEnabled:false];
metal_library = [device newLibraryWithSource:src options:options error:&error]; metal_library = [device newLibraryWithSource:src options:options error:&error];
...@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ...@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
...@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ...@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
...@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ...@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
...@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ...@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
...@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node( ...@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
default: GGML_ABORT("MUL MAT-MAT not implemented"); default: GGML_ABORT("MUL MAT-MAT not implemented");
} }
...@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node( ...@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node(
smem = 32*sizeof(float); smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
} break; } break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
} break;
default: default:
{ {
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
...@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node( ...@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
default: GGML_ABORT("MUL_MAT_ID not implemented"); default: GGML_ABORT("MUL_MAT_ID not implemented");
} }
...@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node( ...@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node(
smem = 32*sizeof(float); smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
} break; } break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
} break;
default: default:
{ {
GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
......
...@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl( ...@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl(
device const char * src1, device const char * src1,
device char * dst, device char * dst,
threadgroup char * shmem, threadgroup char * shmem,
uint3 tgpig, uint3 tgpig, // Threadgroup Position in Grid
ushort tiisg, ushort tiisg, // Thread Index in SIMD Group
ushort sgitg) { ushort sgitg) { // SIMD Group Index in ThreadGroup
const int nb = args.ne00/QK4_0; const int nb = args.ne00/QK4_0; // src0->ne[0] / 32
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
const int im = tgpig.z; const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0; const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4
const uint i12 = im%args.ne12; const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12; const uint i13 = im/args.ne12;
...@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id( ...@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id(
} }
} }
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
float4x4 reg_f;
const ushort dst_bias = 15;
const ushort dst_0p5 = 0x3800;
const ushort dst_m_bits = 10;
const half scale = (half)(as_type<float>(((uint32_t)xb->d) << 23));
// il:0 first 16, il:1 last 16
for (int i = 0; i < 8; i++) {
ushort em0 = xb->qs[il*8 + i] & 0x07;
ushort em1 = xb->qs[il*8 + i] & 0x70;
// float16 values
ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing
if (isnan(scale)) {
reg_f[i/2][2*(i%2) + 0] = scale;
reg_f[i/2][2*(i%2) + 1] = scale;
} else {
reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
}
}
reg = (type4x4) reg_f;
}
#define QK_NL 16 #define QK_NL 16
// //
...@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m ...@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
// //
// indirect matrix-matrix multiplication // indirect matrix-matrix multiplication
// //
...@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m ...@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
// //
// matrix-vector multiplication // matrix-vector multiplication
...@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id( ...@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id(
sgitg); sgitg);
} }
// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y
void mul_mv_mxfp4_f32_impl(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const ushort dst_bias = 15;
const ushort dst_0p5 = 0x3800;
const ushort dst_m_bits = 10;
const int nr0 = N_R0_MXFP4;
const int nsg = N_SG_MXFP4;
const int nw = N_SIMDWIDTH;
const int nb = args.ne00/MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_mxfp4 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
}
float yl[16]; // src1 vector cache
float sumf[nr0] = {0.f};
const short ix = (tiisg/2);
const short il = (tiisg%2)*16;
device const float * yb = y + ix*MXFP4 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
#pragma unroll
for (short row = 0; row < nr0; row++) {
// Processes 16 items
device const block_mxfp4 * qb_curr = ax[row] + ib;
float d = as_type<float>(((uint32_t)(ax[row] + ib)->d) << 23);
// il = 0 or 16
device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
for (int i = 0; i < 8; ++i) {
ushort em0 = qs[i] & 0x07;
ushort em1 = qs[i] & 0x70;
ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing
if (!isnan(d)) {
sumf[row] += yb[i*2] * as_type<half>(x0) * d
+ yb[i*2+1] * as_type<half>(x1) * d;
} else {
sumf[row] = d;
}
}
}
yb += MXFP4 * 16;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst_f32[first_row + row] = tot;
}
}
}
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t; typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
...@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t ...@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;
kernel void kernel_pool_2d_max_f32( kernel void kernel_pool_2d_max_f32(
device const float * src0, device const float * src0,
device float * dst, device float * dst,
......
...@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE ...@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE
quantize_iq2_s(x, y, 1, k, NULL); quantize_iq2_s(x, y, 1, k, NULL);
} }
// =============================== mxfp4 (de)-quantization
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = MXFP4;
static const uint32_t E8_BIAS = 127;
static const uint32_t E2_BIAS = 1;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
}
}
const float dequant_scale = amax / 6.0f;
uint32_t dequant_scale_exponent = 0;
memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent));
// Rounding up
dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000;
// Rounding down
// dequant_scale_exponent = dequant_scale_exponent & 0x7F800000;
float dequant_scale_rounded = 0.0f;
memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded));
float quant_scale = 0.0f;
if (dequant_scale_rounded != 0.0f) {
quant_scale = 1.0f / dequant_scale_rounded;
}
y[i].d = (uint8_t)(dequant_scale_exponent >> 23);
for (int j = 0; j < qk/2; ++j) {
const float x0 = x[i*qk + j*2]*quant_scale;
const float x1 = x[i*qk + j*2+1]*quant_scale;
uint32_t xi0 = 0;
uint32_t xi1 = 0;
memcpy(&xi0, &x0, sizeof(xi0));
memcpy(&xi1, &x1, sizeof(xi1));
uint32_t s0 = xi0 & 0x80000000;
uint32_t s1 = xi1 & 0x80000000;
uint32_t e0 = (xi0 >> 23) & 0xFF;
uint32_t e1 = (xi1 >> 23) & 0xFF;
uint32_t m0 = (xi0 & 0x7FFFFF);
uint32_t m1 = (xi1 & 0x7FFFFF);
// 0.25 <= x < 0.75 maps to 0.5, a denormal number
// Move implicit bit 1 at the beginning to mantissa for denormals
// adjusted_exponents
uint32_t ae0 = E8_BIAS - (e0 + 1);
uint32_t ae1 = E8_BIAS - (e1 + 1);
if (e0 < E8_BIAS) {
m0 = (0x400000 | (m0 >> 1)) >> ae0;
}
if (e1 < E8_BIAS) {
m1 = (0x400000 | (m1 >> 1)) >> ae1;
}
// For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
// Combine sign, exponent, and mantissa, while saturating
// rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7);
uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7);
uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0);
uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1);
y[i].qs[j] = v0;
y[i].qs[j] |= v1 << 4;
}
}
}
void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
assert(k % MXFP4 == 0);
const int nb = k / MXFP4;
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
for (int i = 0; i < nb; i++) {
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x[i].d) << 23);
for (int j = 0; j < MXFP4/2; ++j) {
uint16_t em0 = x[i].qs[j] & 0x07;
uint16_t em1 = x[i].qs[j] & 0x70;
// float16 values
uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12);
uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing
if (isnan(scale.as_value)) {
y[i*MXFP4 + j*2] = scale.as_value;
y[i*MXFP4 + j*2+1] = scale.as_value;
} else {
y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value;
y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value;
}
}
}
}
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}
// =============================== data validation // =============================== data validation
static bool validate_float(float f, size_t i) { static bool validate_float(float f, size_t i) {
...@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte ...@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{ {
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break; } break;
case GGML_TYPE_MXFP4:
// TODO - anything to validate?
break;
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
case GGML_TYPE_I32: case GGML_TYPE_I32:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment