OpenDAS / ollama · Commits

Commit aa43da4b
Authored Aug 03, 2025 by Daniel Hiltgen
Committed by Michael Yang, Aug 04, 2025

cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4

Parent: 0ac1c0d3
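For orientation: before this change the kernel decoded one packed qs byte (two mxfp4 elements) per loop iteration; afterwards it reads four qs bytes (eight elements) per iteration and pairs them with two float4 loads of y. Below is a rough host-side reference of the decode being performed, a sketch only: the e2m1 value table and all names are illustrative assumptions, while the kernel itself (see the diff) builds float16 bit patterns directly rather than using a table.

// Illustrative sketch only (not part of this commit): decode the 8 mxfp4
// elements held in 4 consecutive qs bytes of one block, the amount the
// kernel now consumes per loop iteration.
#include <cstdint>

static const float E2M1_VALUES[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// d is the block's E8M0 scale byte; qs4 points at 4 packed bytes (8 nibbles).
static void decode8_mxfp4_ref(uint8_t d, const uint8_t * qs4, float out[8]) {
    union { uint32_t as_bits; float as_value; } scale;
    scale.as_bits = ((uint32_t) d) << 23;  // same trick as the kernel: scale = 2^(d-127)
    for (int b = 0; b < 4; ++b) {
        out[2*b + 0] = E2M1_VALUES[qs4[b] & 0x0f] * scale.as_value;         // low nibble
        out[2*b + 1] = E2M1_VALUES[(qs4[b] & 0xf0) >> 4] * scale.as_value;  // high nibble
    }
}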
Showing 1 changed file with 72 additions and 46 deletions.

ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu (+72, -46)
@@ -10,8 +10,8 @@ typedef union {
 template <typename type_acc, int block_size>
 // TODO type_acc unused - consider bf16 support
 static __global__ void mul_mat_vec_mxfp4(
-        const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
+        const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
     const int64_t row = blockIdx.x;
@@ -23,16 +23,20 @@ static __global__ void mul_mat_vec_mxfp4(
     const int64_t sample_y = sample_dst;
     const int tid = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int64_t ncols8 = ncols / 8;
     const uint16_t dst_bias = 15;
     const uint16_t dst_0p5 = 0x3800;
     const uint16_t dst_m_bits = 10;
-    x += sample_x*stride_sample_x + channel_x*stride_channel_x + row*stride_row;
-    y += sample_y*stride_sample_y + channel_y*stride_channel_y;
+    // x_base is offset by blocks of 32 elements
+    x_base += sample_x*stride_sample_x + channel_x*stride_channel_x + row*stride_row;
+    // y is offset by elements
+    y += sample_y*stride_sample_y + channel_y*stride_channel_y;
+    // dst is offset by elements
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
-    const float2 * y2 = (const float2 *) y;
+    const float4 * y4 = (const float4 *) y;
     extern __shared__ char data_mmv[];
     // allocated in GPU shared memory: warp_size*sizeof(float)
     float * buf_iw = (float *) data_mmv;
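A side note on the float4 reinterpretation added above: one float4 load moves 16 bytes per thread in a single transaction instead of four scalar loads, provided the pointer is 16-byte aligned. A minimal sketch of the access pattern the kernel now uses for y; dot8_sketch and its arguments are assumptions for illustration, not code from this file.

// Sketch (assumptions as noted): what reading y through const float4* buys.
__device__ float dot8_sketch(const float * y, const float x[8]) {
    const float4 * y4 = (const float4 *) y;   // reinterpret, no copy; y assumed 16-byte aligned
    const float4 a = y4[0];                   // elements 0..3 in one 16-byte load
    const float4 b = y4[1];                   // elements 4..7 in one 16-byte load
    return a.x*x[0] + a.y*x[1] + a.z*x[2] + a.w*x[3]
         + b.x*x[4] + b.y*x[5] + b.z*x[6] + b.w*x[7];
}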
@@ -46,50 +50,72 @@ static __global__ void mul_mat_vec_mxfp4(
     float sumf = 0.0f;
-    for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-        int offset0 = col2 / (MXFP4/2);
-        int i = col2 % (MXFP4/2);
-        const block_mxfp4 * x2 = x + offset0;
+    // each i8 index proceses 8 items at a time
+    for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
+        // As i8 indexes past a block, we have to offset further
+        int offset0 = i8 / (MXFP4/8);
+        int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
+        const block_mxfp4 * x = x_base + offset0;
         union {
             uint32_t as_bits;
             float as_value;
         } scale;
-        scale.as_bits = (((uint32_t) x2->d) << 23);
-        uint16_t em0 = x2->qs[i] & 0x07;
-        uint16_t em1 = x2->qs[i] & 0x70;
-        // float16 values
-        f16_t x0;
-        f16_t x1;
-        x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
-        x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
+        scale.as_bits = (((uint32_t) x->d) << 23);
+        if (isnan(scale.as_value)) {
+            sumf = scale.as_value;
+            break;
+        }
+        const uint8_t qs[4] = {
+            (uint8_t)(x->qs[xi]),
+            (uint8_t)(x->qs[xi + 1]),
+            (uint8_t)(x->qs[xi + 2]),
+            (uint8_t)(x->qs[xi + 3])
+        };
+        const uint8_t el[8] = {
+            (uint8_t)(qs[0] & 0xf),
+            (uint8_t)((qs[0] & 0xf0) >> 4),
+            (uint8_t)(qs[1] & 0xf),
+            (uint8_t)((qs[1] & 0xf0) >> 4),
+            (uint8_t)(qs[2] & 0xf),
+            (uint8_t)((qs[2] & 0xf0) >> 4),
+            (uint8_t)(qs[3] & 0xf),
+            (uint8_t)((qs[3] & 0xf0) >> 4)
+        };
+        uint16_t em[8];
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            em[i] = (uint16_t)(el[i] & 0x07);
+        }
+        // float16 values
+        f16_t x4u[8];
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12);
+        }
         // Three cases:
         // x is normal and non-zero: Correct bias
-        if ((em0 & 0x06) != 0) {
-            x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
-        }
-        if ((em1 & 0x60) != 0) {
-            x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
-        }
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            if ((em[i] & 0x06) != 0) {
+                x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits);
+            }
+        }
         // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
-        if (em0 == 0x01) {
-            x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
-        }
-        if (em1 == 0x10) {
-            x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
-        }
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            if (em[i] == 0x01) {
+                x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000);
+            }
+        }
         // x is zero, do nothing
-        if (isnan(scale.as_value)) {
-            sumf = scale.as_value;
-            break;
-        }
-        const float2 tmpx = {x0.f16, x1.f16};
-        const float2 tmpy = y2[col2];
-        sumf += tmpx.x * tmpy.x * scale.as_value;
-        sumf += tmpx.y * tmpy.y * scale.as_value;
+        const float scalef = scale.as_value;
+        const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
+        const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
+        const float4 tmpy0 = y4[i8 * 2];
+        const float4 tmpy1 = y4[i8 * 2 + 1];
+        sumf += tmpx0.x * tmpy0.x * scalef;
+        sumf += tmpx0.y * tmpy0.y * scalef;
+        sumf += tmpx0.z * tmpy0.z * scalef;
+        sumf += tmpx0.w * tmpy0.w * scalef;
+        sumf += tmpx1.x * tmpy1.x * scalef;
+        sumf += tmpx1.y * tmpy1.y * scalef;
+        sumf += tmpx1.z * tmpy1.z * scalef;
+        sumf += tmpx1.w * tmpy1.w * scalef;
     }
     sumf = warp_reduce_sum<warp_size>(sumf);
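The new indexing can be checked independently of the GPU: each i8 addresses bytes xi..xi+3 of block offset0, i.e. byte offset i8*4 from the start of the row, matching the two float4 loads y4[i8*2] and y4[i8*2+1]. A small host-side check, assuming MXFP4 == 32 (elements per block, hence MXFP4/2 == 16 packed qs bytes per block, as the "blocks of 32 elements" comment in the kernel suggests):

// Host-side sanity check of the i8 -> (offset0, xi) mapping (illustrative only).
#include <cassert>
#include <cstdint>

int main() {
    const int     MXFP4  = 32;          // assumed block size in elements
    const int64_t ncols  = 4096;        // example row width, a multiple of 8
    const int64_t ncols8 = ncols / 8;   // what the kernel iterates over
    for (int64_t i8 = 0; i8 < ncols8; ++i8) {
        const int64_t offset0 = i8 / (MXFP4 / 8);        // which block_mxfp4 in the row
        const int64_t xi      = (i8 % (MXFP4 / 8)) * 4;  // first of the 4 qs bytes read
        assert(offset0 * (MXFP4 / 2) + xi == i8 * 4);    // consecutive 4-byte steps
        assert(xi + 3 < MXFP4 / 2);                      // never crosses a block boundary
    }
    return 0;
}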
@@ -151,42 +177,42 @@ static void launch_mul_mat_vec_cuda_mxfp4(
     switch (block_size_best) {
         case  32: {
             mul_mat_vec_mxfp4<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  64: {
             mul_mat_vec_mxfp4<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  96: {
             mul_mat_vec_mxfp4<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
             mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
             mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
             mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
             mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
             mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                (x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
                  stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {