Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5c48001
Unverified
Commit
d5c48001
authored
Feb 05, 2026
by
Hashem Hashemi
Committed by
GitHub
Feb 05, 2026
Browse files
Adds padding and perf improvements to wvSplitK_fp8 (#33527)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
42d5d705
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
169 additions
and
229 deletions
+169
-229
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+126
-174
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+40
-52
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
...el_executor/layers/quantization/kernels/scaled_mm/rocm.py
+3
-3
No files found.
csrc/rocm/skinny_gemms.cu
View file @
d5c48001
...
...
@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kap
,
const
int
Kbp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
...
...
@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min__
(
Kap
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k
]),
(
int
*
)(
&
s
[
k
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
#endif
}
asm
volatile
(
"s_waitcnt vmcnt(0)"
);
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
...
...
@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
using
floatx16
=
__attribute__
((
__vector_size__
(
16
*
sizeof
(
float
))))
float
;
floatx16
sum
[
N
][
YTILE
];
float
sA
=
*
s_A
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
{
0.
f
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
floatx16
sum
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
++
n
)
bigA
[
n
][
k2
].
h8
=
{
0.
f
};
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
++
y
)
bigB
[
y
][
k2
].
h8
=
{
0.
f
};
}
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
fp8_t
*
B_
=
&
B
[(
m
+
0
)
*
Kp
+
k_
];
const
fp8_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
++
y
)
{
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
y
*
Kp
])));
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
min__
(
y
+
m
,
M
-
1
)
*
K
b
p
])));
}
}
...
...
@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
ap
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
if
(
k
>=
K
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
...
...
@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
1
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
9
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
2
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
10
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
3
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
11
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
4
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
12
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
5
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
13
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
6
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
14
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
7
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
15
]),
"v"
(
accm16
));
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
9
],
0x101
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
10
],
0x102
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
11
],
0x103
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
4
],
0x108
,
0xf
,
0xf
,
1
);
// row_shl8
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
12
],
0x108
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
5
],
0x109
,
0xf
,
0xf
,
1
);
// row_shl9
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
13
],
0x109
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
6
],
0x10a
,
0xf
,
0xf
,
1
);
// row_shl10
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
14
],
0x10a
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
7
],
0x10b
,
0xf
,
0xf
,
1
);
// row_shl11
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
15
],
0x10b
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
...
...
@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if
(
threadIdx
.
x
==
0
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
sum
[
n
][
y
][
0
]
*=
sA
*
sB
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
(
BIAS
)
sum
[
n
][
y
][
0
]
+=
__half2float
(
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
][
0
]
+=
__half2float
(
biases
[
n
][
y
]);
}
else
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum
[
n
][
y
][
0
]
+=
__bfloat162float
(
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
][
0
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
}
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
][
0
]);
// * sA * sB);
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
][
0
]);
}
}
}
...
...
@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
K
a
p
,
const
int
Kbp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
...
...
@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kap
,
const
int
Kbp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
...
...
@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min__
(
Kap
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k
]),
(
int
*
)(
&
s
[
k
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
#endif
}
asm
volatile
(
"s_waitcnt vmcnt(0)"
);
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
...
...
@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
using
floatx16
=
__attribute__
((
__vector_size__
(
16
*
sizeof
(
float
))))
float
;
floatx16
sum
[
N
][
YTILE
];
float
sA
=
*
s_A
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
sum
[
n
][
i
]
=
{
0
};
bigType
bigA
[
N
][
UNRL
];
floatx16
sum
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
fp8_t
*
B_
=
&
B
[(
m
+
0
)
*
Kp
+
k_
];
const
fp8_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
y
*
Kp
])));
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
min__
(
y
+
m
,
M
-
1
)
*
Kbp
])));
}
}
...
...
@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
if
(
k_
+
K
ap
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
ap
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
ap
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
...
...
@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
1
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
9
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
2
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
10
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
3
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
11
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
4
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
12
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
5
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
13
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
6
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
14
]),
"v"
(
accm16
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm0
)
:
"0"
(
accm0
),
"v"
(
sum
[
n
][
y
][
7
]),
"v"
(
accm0
));
asm
(
"v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
:
"=v"
(
accm16
)
:
"0"
(
accm16
),
"v"
(
sum
[
n
][
y
][
15
]),
"v"
(
accm16
));
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
9
],
0x101
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
10
],
0x102
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
11
],
0x103
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
4
],
0x108
,
0xf
,
0xf
,
1
);
// row_shl8
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
12
],
0x108
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
5
],
0x109
,
0xf
,
0xf
,
1
);
// row_shl9
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
13
],
0x109
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
6
],
0x10a
,
0xf
,
0xf
,
1
);
// row_shl10
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
14
],
0x10a
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
7
],
0x10b
,
0xf
,
0xf
,
1
);
// row_shl11
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
15
],
0x10b
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
...
...
@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if
(
threadIdx
.
x
==
0
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
sum
[
n
][
y
][
0
]
*=
sA
*
sB
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
(
BIAS
)
sum
[
n
][
y
][
0
]
+=
__half2float
(
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
][
0
]
+=
__half2float
(
biases
[
n
][
y
]);
}
else
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum
[
n
][
y
][
0
]
+=
__bfloat162float
(
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
][
0
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
}
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
][
0
]);
}
...
...
@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
K
a
p
,
const
int
Kbp
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
fp8_t
*
B
,
const
fp8_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
...
...
@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void
wvSplitKQ
(
const
at
::
Tensor
&
in_
a
,
const
at
::
Tensor
&
in_
b
,
void
wvSplitKQ
(
const
at
::
Tensor
&
in_
b
,
const
at
::
Tensor
&
in_
a
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
at
::
Tensor
&
out_c
,
const
at
::
Tensor
&
scale_a
,
const
at
::
Tensor
&
scale_b
,
const
int64_t
CuCount
)
{
static
c10
::
ScalarType
kFp8Type
=
is_fp8_ocp
()
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
auto
M_in
=
in_a
.
size
(
0
);
auto
K_in
=
in_a
.
size
(
1
);
auto
N_in
=
in_b
.
size
(
0
);
auto
Kp_in
=
in_a
.
stride
(
0
);
auto
M_in
=
in_b
.
size
(
0
);
auto
K_in
=
in_b
.
size
(
1
);
auto
N_in
=
in_a
.
size
(
0
);
auto
Kap_in
=
in_a
.
stride
(
0
);
auto
Kbp_in
=
in_b
.
stride
(
0
);
auto
Bx_in
=
(
in_bias
.
has_value
()
&&
in_bias
->
numel
()
>
0
)
?
(
in_bias
->
sizes
().
size
()
==
2
)
?
in_bias
->
size
(
1
)
:
in_bias
->
size
(
0
)
...
...
@@ -2300,22 +2253,21 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
max_lds_len
=
get_lds_size
();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) {
\
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs,
_WvPrGrp);
\
if ((K
ap
_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp =
min(_WvPrGrp,
mindiv(M_in, CuCount * _YTILEs,
16));
\
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in,
By_in, a_ptr,
\
b
_ptr, bias_ptr, c_ptr,
s_a, s_b,
\
__wvPrGrp, CuCount);
\
<<<grid, block, 0, stream>>>(K_in, K
ap_in, Kb
p_in, M_in, Bx_in,
\
By_in, b_ptr, a
_ptr, bias_ptr, c_ptr, \
s_a, s_b,
__wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm,
_WvPrGrp);
\
int __wvPrGrp =
min(_WvPrGrp,
mindiv(M_in, CuCount * _YTILEm,
16));
\
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in,
By_in, a_ptr,
\
b
_ptr, bias_ptr, c_ptr,
s_a, s_b,
\
__wvPrGrp, CuCount);
\
<<<grid, block, 0, stream>>>(K_in, K
ap_in, Kb
p_in, M_in, Bx_in,
\
By_in, b_ptr, a
_ptr, bias_ptr, c_ptr, \
s_a, s_b,
__wvPrGrp, CuCount); \
} \
}
...
...
@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
:
nullptr
;
switch
(
N_in
)
{
case
1
:
WVSPLITKQ
(
1
6
,
2
,
2
,
2
,
2
,
2
,
2
,
1
)
WVSPLITKQ
(
12
,
2
,
2
,
2
,
2
,
1
)
break
;
case
2
:
WVSPLITKQ
(
1
6
,
2
,
2
,
2
,
2
,
2
,
2
,
2
)
WVSPLITKQ
(
12
,
2
,
2
,
2
,
2
,
2
)
break
;
case
3
:
WVSPLITKQ
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
3
)
WVSPLITKQ
(
8
,
2
,
2
,
1
,
1
,
3
)
break
;
case
4
:
WVSPLITKQ
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
4
)
WVSPLITKQ
(
4
,
2
,
2
,
1
,
1
,
4
)
break
;
default:
throw
std
::
runtime_error
(
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
d5c48001
...
...
@@ -73,21 +73,40 @@ NKM_FACTORS_WVSPLITKRC = [
NKM_FACTORS_WVSPLITK_FP8
=
[
# FP8-specific cases with K % 16 == 0
(
1
,
16
,
16
),
(
1
,
32
,
16
+
16
),
(
1
,
64
,
64
),
(
1
,
64
,
64
+
16
),
(
1
,
64
+
16
,
64
),
(
1
,
64
+
16
,
64
+
16
),
(
4
,
64
,
64
),
(
4
,
64
,
64
+
16
),
(
4
,
64
+
16
,
64
),
(
4
,
64
+
16
,
64
+
16
),
(
2
,
512
,
512
),
(
3
,
512
,
512
),
(
3
,
512
,
512
+
16
),
(
4
,
512
,
512
),
(
3
,
2048
,
2048
),
(
3
,
2048
,
2048
+
16
),
(
4
,
2048
+
16
,
2048
),
(
4
,
2048
+
16
,
2048
+
16
),
(
4
,
4096
,
4096
),
(
4
,
16400
,
2048
),
(
4
,
16400
,
2048
+
16
),
# Extended FP8 dimensions not covered by WVSPLITK
(
1
,
14336
,
1024
),
(
2
,
24576
,
2048
),
(
4
,
32768
,
28672
),
(
4
,
32768
*
2
,
28672
),
(
4
,
32768
*
2
,
28672
+
16
),
(
4
,
32768
*
2
+
16
,
28672
),
(
4
,
32768
*
2
+
16
,
28672
+
16
),
]
SEEDS
=
[
0
]
def
pad_
weights_
fp8
(
weight
):
def
pad_fp8
(
weight
):
num_pad
=
256
//
weight
.
element_size
()
import
torch.nn.functional
as
F
...
...
@@ -195,72 +214,41 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
@
pytest
.
mark
.
parametrize
(
"xnorm"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK_FP8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"padded"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"padded_a"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"padded_b"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"biased"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
,
)
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
,
padded
):
def
test_rocm_wvsplitk_fp8_kernel
(
xnorm
,
n
,
k
,
m
,
dtype
,
seed
,
padded_a
,
padded_b
,
biased
):
torch
.
manual_seed
(
seed
)
A
=
torch
.
rand
(
n
,
k
,
device
=
"cuda"
)
-
0.5
B
=
torch
.
rand
(
m
,
k
,
device
=
"cuda"
)
-
0.5
xavier
=
math
.
sqrt
(
2
/
k
)
if
xnorm
else
1
# normalize to avoid large deltas
A
=
(
torch
.
rand
(
n
,
k
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
A
,
scale_a
=
ref_dynamic_per_tensor_fp8_quant
(
A
)
B
,
scale_b
=
ref_dynamic_per_tensor_fp8_quant
(
B
)
if
padded
:
B
=
pad_weights_fp8
(
B
)
ref_out
=
torch
.
_scaled_mm
(
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
)
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
)
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
if
padded_b
:
B
=
pad_fp8
(
B
)
if
padded_a
:
A
=
pad_fp8
(
A
)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK_FP8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"padded"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
,
)
def
test_rocm_wvsplitk_fp8_bias1D_kernel
(
n
,
k
,
m
,
dtype
,
seed
,
padded
):
torch
.
manual_seed
(
seed
)
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
device
=
"cuda"
)
-
0.5
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
device
=
"cuda"
)
-
0.5
)
*
xavier
BIAS
=
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
A
,
scale_a
=
ref_dynamic_per_tensor_fp8_quant
(
A
)
B
,
scale_b
=
ref_dynamic_per_tensor_fp8_quant
(
B
)
if
padded
:
B
=
pad_weights_fp8
(
B
)
BIAS
=
None
if
(
not
biased
)
else
(
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
ref_out
=
torch
.
_scaled_mm
(
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
BIAS
)
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
BIAS
,
)
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
BIAS
)
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
if
xnorm
:
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-8
)
else
:
assert
torch
.
allclose
(
out
,
ref_out
,
0.01
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
View file @
d5c48001
...
...
@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
bias
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
(
A
.
shape
[
0
]
==
1
and
B
.
shape
[
1
]
%
16
==
0
A
.
shape
[
0
]
<=
4
and
B
.
shape
[
0
]
%
16
==
0
# M TODO: needed?
and
B
.
shape
[
1
]
%
16
==
0
# K
and
((
bias
is
None
)
or
(
bias
.
dtype
==
out_dtype
))
and
A
.
is_contiguous
()
):
output
=
ops
.
wvSplitKQ
(
B
.
t
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment