OpenDAS / ollama · Commits

Commit 3a9e8e9f (unverified)
Authored Nov 12, 2025 by Daniel Hiltgen, committed by GitHub Nov 12, 2025
vulkan: temporary carry of vulkan fixes (#12971)
This should be reverted once we update ggml past b6897
parent cb1cb064

Changes (32): showing 12 changed files with 855 additions and 385 deletions (+855 -385)
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp                 +1    -71
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl           +7    -7
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl        +70   -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp                +71   -217
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl          +505  -33
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl    +78   -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl              +2    -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp              +10   -3
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp              +10   -3
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp               +61   -29
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl                  +33   -20
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp      +7    -2
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
...
...
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
...
...
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128;                         // 2 values per idx
const uint iqs = idx % 128;                        // 0..127

const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
const uint scalesi = iqs / 8;                      // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6

const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);
const vec2 dm = vec2(data_a[ib].dm);

const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);

buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
...
...
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b;                  // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2;   // 0,2,4..126

const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
...
...
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));

const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
...
...
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;

const float d = e8m0_to_fp32(data_a[ib].e);
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs + 1]);
...
...
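A side note on the mxfp4 hunk above: multiplying e8m0_to_fp32(data_a[ib].e) by 0.5 pairs with the kvalues_mxfp4 table in types.glsl (later in this diff) switching from half-step floats to doubled int8 entries, so the decoded values stay the same. The following standalone C++ check of that equivalence is an illustration only; the two tables are copied from the diff, while the scale value d is an arbitrary example.

// Illustration only (not part of the diff): halving the scale compensates for the
// doubled int8 kvalues table, so d_old * kvalues_f32[i] == (d_old * 0.5) * kvalues_i8[i].
#include <cstdio>

int main() {
    const float kvalues_f32[16] = { 0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                                   -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
    const int   kvalues_i8[16]  = { 0,  1,  2,  3,  4,  6,  8,  12,
                                    0, -1, -2, -3, -4, -6, -8, -12};
    const float d = 0.125f; // stand-in for e8m0_to_fp32(data_a[ib].e)
    for (int i = 0; i < 16; ++i) {
        const float old_val = d * kvalues_f32[i];
        const float new_val = (d * 0.5f) * (float)kvalues_i8[i];
        std::printf("%2d: %g == %g\n", i, old_val, new_val);
    }
    return 0;
}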
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
0 → 100644
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];

void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
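For reference, the subgroup-ballot logic in load_row_ids is a stream compaction: each (ii0, ii1) pair whose expert id matches expert_idx receives a dense slot equal to the number of matches before it, and only slots inside the current workgroup's [ic*BN, (ic+1)*BN) window are stored. The scalar C++ sketch below expresses the same compaction without any GPU semantics; it is a simplified illustration that assumes a contiguous expert-id matrix rather than the shader's p.nbi1 stride.

// Scalar analogue (illustration only) of the compaction performed by load_row_ids.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<uint16_t, uint16_t>> load_row_ids_scalar(
        const std::vector<uint32_t>& data_ids, uint32_t nei0, uint32_t nei1,
        uint32_t expert_idx, uint32_t ic, uint32_t BN) {
    std::vector<std::pair<uint16_t, uint16_t>> row_ids(BN);
    uint32_t ne1 = 0; // running count of matches (the shader's _ne1)
    for (uint32_t ii1 = 0; ii1 < nei1; ++ii1) {
        for (uint32_t ii0 = 0; ii0 < nei0; ++ii0) {
            if (data_ids[ii1 * nei0 + ii0] != expert_idx) continue;
            // only the slice owned by workgroup ic is materialized
            if (ne1 >= ic * BN && ne1 < (ic + 1) * BN) {
                row_ids[ne1 - ic * BN] = {uint16_t(ii0), uint16_t(ii1)};
            }
            ++ne1; // in the shader this increment happens in bulk via subgroupBallotBitCount
        }
    }
    return row_ids;
}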
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
...
...
@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
...
...
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
...
...
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
#define MMQ_SHMEM
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#include "mul_mmq_shmem_types.glsl"
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
...
...
@@ -139,26 +132,12 @@ void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
...
...
@@ -172,17 +151,27 @@ void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
...
...
@@ -209,159 +198,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
const uint buf_ib = loadc_b + l;
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint iqs = loadr_b;
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
}
}
}
#endif
barrier();
}
...
...
@@ -373,54 +273,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
...
...
@@ -431,19 +283,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}
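The mmq path above accumulates integer dot products (dotPacked4x8EXT / mmq_dot_product) and applies the per-block scales afterwards through the mul_q8_1 helpers in mul_mmq_funcs.glsl below. A hedged C++ check of the identity this relies on for scale/min (Q4_1-style) blocks follows; it assumes the q8_1 block stores ds = (d, d * sum of quants), as ggml does, and uses made-up example quant values.

// Illustration (not from the diff) of the scaling identity behind mul_q8_1:
//   sum_i (d_a*qa_i + m_a) * (d_b*qb_i) = d_a*d_b * sum_i qa_i*qb_i + m_a * (d_b * sum_i qb_i)
#include <cstdio>

int main() {
    const int qa[8] = {1, 3, 0, 7, 2, 5, 4, 6};      // A quants (example values)
    const int qb[8] = {-3, 12, 7, -1, 0, 9, -8, 4};  // B (q8_1) quants (example values)
    const float d_a = 0.02f, m_a = -0.1f;            // A block scale and min
    const float d_b = 0.05f;                         // B block scale

    float reference = 0.0f;
    int   q_sum = 0, qb_sum = 0;
    for (int i = 0; i < 8; ++i) {
        reference += (d_a * qa[i] + m_a) * (d_b * qb[i]); // dequantize-then-multiply
        q_sum  += qa[i] * qb[i];                          // what dotPacked4x8EXT accumulates
        qb_sum += qb[i];
    }
    const float dsb_y = d_b * qb_sum;                     // the precomputed ds.y of a q8_1 block
    const float mmq   = q_sum * d_a * d_b + m_a * dsb_y;  // mul_q8_1 with sum_divisor == 1
    std::printf("reference=%f mmq=%f\n", reference, mmq);
    return 0;
}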
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
...
...
@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
    // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2], data_a[ib].qs[iqs * 2 + 1]);
#ifdef DATA_A_Q4_0
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    return i32vec2(vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
    const uint32_t vui = data_a_packed32[ib].qs[iqs];
    return i32vec2(vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F);
#endif
}

#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
    // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
    const uint32_t vui = data_a_packed32[ib].qs[iqs];
    return i32vec2(vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F);

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]));

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
    }
#else // DATA_A_Q4_1
    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
    }
#endif
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}
#endif

#if defined(DATA_A_Q5_0)
ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        const uint32_t vui = cache_a[ib_a].qs[iqs];
        const i32vec2 qs_a = i32vec2(vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F);

        const int32_t qs_b0 = cache_b.qs[iqs];
        const int32_t qs_b1 = cache_b.qs[iqs + 4];

        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
    }

    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
    // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2], data_a[ib].qs[iqs * 2 + 1]);
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
#ifdef DATA_A_Q5_0
    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
...
...
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
    return i32vec2(v0, v1);
}

#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
    // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
    const uint32_t vui = data_a_packed32[ib].qs[iqs];
    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]));
    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
    }
#else // DATA_A_Q5_1
    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
    return i32vec2(v0, v1);

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
        buf_a[buf_ib].qh = data_a_packed32[ib].qh;
    }
#endif
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
    cache_a[reg_ib].qh = buf_a[buf_ib].qh;
    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        const uint32_t vui = cache_a[ib_a].qs[iqs];
        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010;                // (0,1,2,3) -> (4,12,20,28)
        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

        const int32_t qs_b0 = cache_b.qs[iqs];
        const int32_t qs_b1 = cache_b.qs[iqs + 4];

        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
    }

    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif

#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
    // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
    return pack32(i16vec2(data_a[ib].qs[iqs * 2], data_a[ib].qs[iqs * 2 + 1]));
    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]));
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * da * dsb.x);
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1]));

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];
        const int32_t qs_b = cache_b.qs[iqs];

        q_sum += dotPacked4x8EXT(qs_a, qs_b);
    }

    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3]));
    return i32vec2(quants & 0x0F0F0F0F, (quants >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * dsb.x * float(q_sum));
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3]));
    const u8vec4 i_a0 = unpack8(qs & 0x0F0F0F0F);
    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);

    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
    buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));

    if (iqs == 0) {
        buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].d = buf_a[buf_ib].d;
    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }

    return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;

    return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}

uint8_t get_scale(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    return data_a[ib_k].scales[iqs_k / 4];
}

ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;

    // Repack 4x4 quants into one int
    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);

    if (iqs == 0) {
        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
    [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t sum_d = 0;
    int32_t sum_m = 0;

    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.

        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);

        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
    }

    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;
    const uint hm_idx = iqs * QUANT_R_MMQ;
    const uint iqs_k = (ib % 8) * 8 + hm_idx;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
    const uint hm_shift = iqs_k / 8;

    // Repack 2x4 quants into one int
    // Add the 3rd bit instead of subtracting it to allow packing the quants
    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2    ] >> qs_shift) & uint16_t(0x0303))) | unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));

    buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) | (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);

    if (iqs == 0) {
        const uint is = iqs_k / 4;
        const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        // Subtract 4 from the quants to correct the 3rd bit offset
        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);

    q_sum = 0;
    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

    return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 16) / 8) * 4;

    // Repack 2x4 quants into one int
#if defined(DATA_A_Q4_K)
    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
#else // defined(DATA_A_Q5_K)
    const uint qh_idx = iqs * QUANT_R_MMQ;
    const uint qh_shift = iqs_k / 8;

    buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) | (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif

    if (iqs == 0) {
        // Scale index
        const uint is = iqs_k / 8;
        u8vec2 scale_dm;
        if (is < 4) {
            scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
        } else {
            scale_dm = u8vec2((data_a[ib_k].scales[is + 4] & 0xF) | ((data_a[ib_k].scales[is - 4] & 0xC0) >> 2), (data_a[ib_k].scales[is + 4] >> 4) | ((data_a[ib_k].scales[is] & 0xC0) >> 2));
        }

        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
    [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
#if defined(DATA_A_Q4_K)
        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
        const int32_t qs_a = cache_a[ib_a].qs[iqs];
#endif
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }

    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_outer = ib / 4;
    const uint ib_inner = ib % 4;

    if (iqs == 0) {
        buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
    }

    const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
    buf_b[buf_ib].qs[iqs * 4    ] = values.x;
    buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
    buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
    buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}

void block_b_to_registers(const uint ib) {
    cache_b.ds = buf_b[ib].ds;
    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
    }
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
    const uint ql_shift = ((iqs_k % 32) / 16) * 4;

    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
    const uint qh_shift = ((iqs_k % 32) / 8) * 2;

    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) | unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));

    if (iqs == 0) {
        const uint is = iqs_k / 4;
        const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
    }
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
    }
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);

    q_sum = 0;
    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
        const int32_t qs_a = cache_a[ib_a].qs[iqs];
        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

    return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
...
...
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif

#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    const uint ib_k = ib / 8;
    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
0 → 100644
#if defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[16/4];
    FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q4_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[16/4];
    FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[16/4];
    uint32_t qh;
    FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q5_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[16/4];
    uint32_t qh;
    FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
// #define BK_STEP 1
struct block_a_cache {
    int32_t qs[32/4];
    FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {
    int32_t qs[8];
    FLOAT_TYPE d;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache {
    uint32_t qs[2];
    u8vec2 scales;
    FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[4];
    FLOAT_TYPE_VEC2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
    uint32_t qs[4];
    FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
    int32_t qs[8];
    FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
    int32_t qs[8];
    FLOAT_TYPE_VEC2 d_scales;
};
#endif

struct block_b_cache {
    int32_t qs[8];
    FLOAT_TYPE_VEC2 ds;
};
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
...
...
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows

layout (push_constant) uniform parameter {
    uint ncols;
...
...
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
    uint s2;
    int sections[4];
    uint is_back;
    uint set_rows_stride;
} p;

float rope_yarn_ramp(const float low, const float high, const uint i0) {
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
...
...
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0/2;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
return;
}
...
...
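The fused ROPE + VIEW + SET_ROWS path above replaces the usual destination index with an offset taken from data_i. The C++ sketch below walks through the index computation with assumed example shapes; ne0, ne1, set_rows_stride and the data_i value are placeholders, not values from the diff.

// Sketch of the fused destination-index computation (assumed example shapes).
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t ne0 = 128, ne1 = 4;          // head size and rows per channel (example)
    const uint32_t set_rows_stride = ne0 * ne1; // row stride of the SET_ROWS destination (assumed)
    const uint32_t data_i_channel = 7;          // stand-in for data_i[channel_x].x
    const uint32_t row_dst = 9, i0 = 10;

    const uint32_t row_x = row_dst % ne1;
    uint32_t idst = row_dst * ne0 + i0 / 2;     // plain rope output index
    if (set_rows_stride != 0) {                 // fused ROPE + VIEW + SET_ROWS path
        idst = row_x * ne0 + i0 / 2 + data_i_channel * set_rows_stride;
    }
    std::printf("idst = %u\n", idst);
    return 0;
}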
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
...
...
@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0;
uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
return;
}
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
...
...
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
float clamp_min;
float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
...
...
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
...
...
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
return;
}
const float INFINITY = 1.0 / 0.0;
const uint logits_offset = n_experts * row;
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
float max_val = -INFINITY;
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
max_val = max(max_val, vals[i]);
}
}
float max_val = logits_r[0];
max_val = subgroupMax(max_val);
float sum = 0.f;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
const float val = exp(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}
max_val = subgroupMax(max_val);
sum = subgroupAdd(sum);
float wt[experts_per_thread];
float tmp = 0.f;
const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
vals[i] *= inv_sum;
}
}
}
tmp = subgroupAdd(tmp);
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
return;
}
const float inv_sum = 1.0f / tmp;
const uint logits_offset = n_experts * row;
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float wt[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
}
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
...
...
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
...
...
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
...
...
@@ -129,6 +157,10 @@ void main() {
}
}
if (late_softmax) {
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
...
...
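The new softmax_warp_inplace helper normalizes either all n_experts logits (the early path) or only the first n_expert_used output weights (the late_softmax path). The C++ analogue below replaces the subgroupMax/subgroupAdd reductions with plain loops; it is an illustration of the limit-aware softmax, not shader code.

// Scalar analogue (illustration only) of softmax_warp_inplace with an active limit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void softmax_inplace(std::vector<float>& vals, size_t limit, bool use_limit) {
    const size_t n = use_limit ? std::min(limit, vals.size()) : vals.size();
    float max_val = -INFINITY;
    for (size_t i = 0; i < n; ++i) max_val = std::max(max_val, vals[i]);
    float sum = 0.f;
    for (size_t i = 0; i < n; ++i) { vals[i] = std::exp(vals[i] - max_val); sum += vals[i]; }
    for (size_t i = n; i < vals.size(); ++i) vals[i] = 0.f; // inactive entries are zeroed
    const float inv_sum = 1.0f / sum;
    for (size_t i = 0; i < n; ++i) vals[i] *= inv_sum;
}

int main() {
    std::vector<float> w = {1.0f, 2.0f, 0.5f, 3.0f};
    softmax_inplace(w, 2, true); // late_softmax path: only the first n_expert_used entries
    for (float v : w) std::printf("%f\n", v);
    return 0;
}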
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
...
...
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
...
...
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
...
...
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
...
...
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
...
...
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
...
...
@@ -226,21 +231,21 @@ struct block_q2_K
{
    uint8_t scales[QUANT_K_Q2_K/16];
    uint8_t qs[QUANT_K_Q2_K/4];
    f16vec2 d;
    f16vec2 dm;
};

struct block_q2_K_packed16
{
    uint16_t scales[QUANT_K_Q2_K/16/2];
    uint16_t qs[QUANT_K_Q2_K/4/2];
    f16vec2 d;
    f16vec2 dm;
};

struct block_q2_K_packed32
{
    uint32_t scales[QUANT_K_Q2_K/16/4];
    uint32_t qs[QUANT_K_Q2_K/4/4];
    f16vec2 d;
    f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
...
...
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#define SCALES_PER_32 2
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
...
...
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
    f16vec2 d;
    f16vec2 dm;
    uint8_t scales[3*QUANT_K_Q4_K/64];
    uint8_t qs[QUANT_K_Q4_K/2];
};

struct block_q4_K_packed16
{
    f16vec2 d;
    f16vec2 dm;
    uint16_t scales[3*QUANT_K_Q4_K/64/2];
    uint16_t qs[QUANT_K_Q4_K/2/2];
};

struct block_q4_K_packed32
{
    f16vec2 d;
    f16vec2 dm;
    uint32_t scales[3*QUANT_K_Q4_K/64/4];
    uint32_t qs[QUANT_K_Q4_K/2/4];
};
...
...
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
    f16vec2 d;
    f16vec2 dm;
    uint8_t scales[12];
    uint8_t qh[QUANT_K_Q5_K/8];
    uint8_t qs[QUANT_K_Q5_K/2];
...
...
@@ -324,12 +333,20 @@ struct block_q5_K
struct block_q5_K_packed16
{
    f16vec2 d;
    f16vec2 dm;
    uint16_t scales[12/2];
    uint16_t qh[QUANT_K_Q5_K/8/2];
    uint16_t qs[QUANT_K_Q5_K/2/2];
};

struct block_q5_K_packed32
{
    f16vec2 dm;
    uint32_t scales[12/4];
    uint32_t qh[QUANT_K_Q5_K/8/4];
    uint32_t qs[QUANT_K_Q5_K/2/4];
};

struct block_q5_K_packed128
{
    uvec4 q5k[11];
...
...
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#define A_TYPE_PACKED32 block_q5_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
...
...
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
{
    uint16_t ql[QUANT_K_Q6_K/2/2];
    uint16_t qh[QUANT_K_Q6_K/4/2];
    int8_t scales[QUANT_K_Q6_K/16];
    int16_t scales[QUANT_K_Q6_K/16/2];
    float16_t d;
};
...
...
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#define DATA_A_QUANT_K
#endif
// IQuants
...
...
@@ -1363,18 +1383,11 @@ struct block_mxfp4
    uint8_t qs[QUANT_K_MXFP4/2];
};
//struct block_mxfp4_packed16
//{
// uint8_t e;
// uint16_t qs[QUANT_K_MXFP4/2/2];
//};
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
...
...
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
#endif
#if defined(DATA_A_MXFP4)
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
    FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
    FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
const int8_t kvalues_mxfp4_const[16] = {
    int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
    int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};

shared FLOAT_TYPE kvalues_mxfp4[16];
shared int8_t kvalues_mxfp4[16];

#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
...
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
...
...
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
    if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
    // Integer dot mmq performs better with f32 accumulators
    if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
        string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
    }
#endif
...
...
@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};

    // matmul
    for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
...
...
@@ -841,10 +842,14 @@ void process_shaders() {
    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});

    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
...
...
@xuxzh1 mentioned in commit 0cf7794b16fab8d4561bc5f6379f6d48bd59e101 · Jan 09, 2026