Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7600642e
Unverified
Commit
7600642e
authored
Feb 28, 2026
by
Hashem Hashemi
Committed by
GitHub
Feb 28, 2026
Browse files
Add padding support to wvSplitK solution for skinny GEMMs (#33762)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
1e69c048
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
289 additions
and
444 deletions
+289
-444
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+256
-402
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+33
-41
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+0
-1
No files found.
csrc/rocm/skinny_gemms.cu
View file @
7600642e
...
...
@@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
wvSplitK_hf_sml_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
...
...
@@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else
constexpr
bool
use_mfma
=
false
;
#endif
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
float
))))
float
;
using
half4
=
...
...
@@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min__
(
K
*
N
,
max_lds_len
)
;
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min__
(
Kap
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k
]),
(
int
*
)(
&
s
[
k
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
#endif
}
__syncthreads
();
...
...
@@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
(
threadIdx
.
y
%
_WvPrGrp
))
*
YTILE
;
float
sum
[
N
][
YTILE
];
scalar8
sum4
[
N
][
YTILE
];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
...
...
@@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
if
constexpr
(
!
use_mfma
)
sum
[
n
][
i
]
=
0
;
else
sum4
[
n
][
i
]
=
{
0
,
0
,
0
,
0
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
float
sum
[
N
][
YTILE
]
=
{};
scalar8
sum4
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
y
*
K
])));
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
min__
(
y
+
m
,
M
-
1
)
*
K
bp
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
...
...
@@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
ap
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#pragma unroll
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
constexpr
(
!
use_mfma
)
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
y
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
y
][
k2
].
f
[
b
])
}
else
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
4
;
b
++
)
sum4
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
bigA
[
n
][
k2
].
h4
[
b
],
bigB
[
y
][
k2
].
h4
[
b
],
sum4
[
n
][
y
],
0
,
0
,
0
);
...
...
@@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if
constexpr
(
!
use_mfma
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x118
,
0xf
,
0xf
,
1
);
// row_shr8
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x114
,
0xf
,
0xf
,
1
);
// row_shr4
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x112
,
0xf
,
0xf
,
1
);
// row_shr2
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x111
,
0xf
,
0xf
,
1
);
// row_shr1
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__half2float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__half2float
(
biases
[
n
][
y
]);
}
else
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
}
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
]);
}
}
}
...
...
@@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
#pragma unroll
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
/*float accm1 = 0;
for (int i=0; i<64; i++)
accm1 += __shfl(sum4[n][y][i%4], i);
sum4[n][y][0] = accm1;*/
float
accm
=
sum4
[
n
][
y
][
0
];
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
1
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
2
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
3
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x104
,
0xf
,
0xf
,
1
);
// row_shl4
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x108
,
0xf
,
0xf
,
1
);
// row_shl8
accm
=
__builtin_amdgcn_mov_dpp
(
accm
,
0x11f
,
0xf
,
0xf
,
1
);
// row_shr15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
sum4
[
n
][
y
][
0
]
=
accm
;
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
BIAS
)
sum4
[
n
][
i
][
0
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
C
[
m
+
i
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
i
][
0
]);
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
sum4
[
n
][
y
][
0
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
C
[
m
+
y
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
y
][
0
]);
}
}
}
...
...
@@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
__global__
void
wvSplitK_hf_sml_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
...
...
@@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
wvSplitK_hf_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
...
...
@@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8
h8
;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__
scalar_t
s
[
max_lds_len
];
//----------------------------------------------------
...
...
@@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
commitColumn
[
i
]
=
1
;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t
m
=
(
blockIdx
.
x
*
_WvPrGrp
+
threadIdx
.
y
)
*
YTILE
;
// Check whether there will be fragmentation!
...
...
@@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
=
startColumn
;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min__
(
Kap
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k
]),
(
int
*
)(
&
s
[
k
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
#endif
}
__syncthreads
();
if
(
threadIdx
.
y
>=
_WvPrGrp
)
return
;
float
sum
[
N
][
YTILE
];
scalar8
sum4
[
N
][
YTILE
];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while
(
m
<
M
)
{
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
if
constexpr
(
!
use_mfma
)
sum
[
n
][
i
]
=
0
;
else
sum4
[
n
][
i
]
=
{
0
,
0
,
0
,
0
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
float
sum
[
N
][
YTILE
]
=
{};
scalar8
sum4
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
// Fetch the weight matrix from memory!
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
for
(
int
b
=
0
;
b
<
YTILE
;
b
++
)
bigB
[
b
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
b
*
K
])));
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
min__
(
y
+
m
,
M
-
1
)
*
Kbp
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
...
...
@@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
if
(
k_
+
K
ap
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
ap
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
ap
*
n
])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
constexpr
(
!
use_mfma
)
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
y
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
y
][
k2
].
f
[
b
])
}
else
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
4
;
b
++
)
sum4
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
bigA
[
n
][
k2
].
h4
[
b
],
bigB
[
y
][
k2
].
h4
[
b
],
sum4
[
n
][
y
],
0
,
0
,
0
);
...
...
@@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if
constexpr
(
!
use_mfma
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x118
,
0xf
,
0xf
,
1
);
// row_shr8
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x114
,
0xf
,
0xf
,
1
);
// row_shr4
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x112
,
0xf
,
0xf
,
1
);
// row_shr2
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x111
,
0xf
,
0xf
,
1
);
// row_shr1
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
commitColumn
[
y
])
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__half2float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__half2float
(
biases
[
n
][
y
]);
}
else
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
}
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
]);
}
}
}
...
...
@@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float
accm
=
sum4
[
n
][
y
][
0
];
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
1
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
2
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
3
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x104
,
0xf
,
0xf
,
1
);
// row_shl4
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x108
,
0xf
,
0xf
,
1
);
// row_shl8
accm
=
__builtin_amdgcn_mov_dpp
(
accm
,
0x11f
,
0xf
,
0xf
,
1
);
// row_shr15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
sum4
[
n
][
y
][
0
]
=
accm
;
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
{
if
(
BIAS
)
sum4
[
n
][
i
][
0
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
C
[
m
+
i
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
i
][
0
]);
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
commitColumn
[
y
])
{
sum4
[
n
][
y
][
0
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
C
[
m
+
y
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
y
][
0
]);
}
}
}
...
...
@@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
__global__
void
wvSplitK_hf_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
B
y
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
...
...
@@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
wvSplitK_hf_big_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
...
...
@@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
#define PCML
#ifndef PCML
for
(
uint32_t
k
=
0
;
k
<
min__
(
K
*
N
,
max_lds_len
)
;
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min__
(
Kap
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k
]),
(
int
*
)(
&
s
[
k
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
#endif
}
__syncthreads
();
#endif
...
...
@@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
?
kFit
:
(
kFit
-
kFit
%
TUC
);
// round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit
=
min__
(
kFit
,
K
);
float
sum
[
N
][
YTILE
];
scalar8
sum4
[
N
][
YTILE
];
kFit
=
min__
(
kFit
,
Kap
);
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
...
...
@@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
if
constexpr
(
!
use_mfma
)
sum
[
n
][
i
]
=
0
;
else
sum4
[
n
][
i
]
=
{
0
,
0
,
0
,
0
};
bigType
bigA
[
N
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
float
sum
[
N
][
YTILE
]
=
{};
scalar8
sum4
[
N
][
YTILE
]
=
{};
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
...
...
@@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
#ifdef PCML
if
((
k1
==
0
)
||
(
k1
==
kBase
+
kFit
))
{
// load next chunk of A[] to LDS
if
(
k1
!=
0
)
kBase
+=
kFit
;
__syncthreads
();
for
(
uint32_t
k
=
0
;
k
<
kFit
;
k
+=
THRDS
*
_WvPrGrp
*
A_CHUNK
)
{
uint32_t
kOff
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
kBase
+
kOff
>=
K
)
break
;
if
(
kBase
+
kOff
>=
K
ap
)
break
;
if
(
kOff
>=
kFit
)
break
;
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
uint32_t
k_in
=
kBase
+
n
*
K
+
kOff
;
uint32_t
k_in
=
kBase
+
n
*
K
ap
+
kOff
;
uint32_t
k_ot
=
n
*
kFit
+
kOff
;
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k_in
]),
(
int
*
)(
&
s
[
k_ot
]),
16
,
0
,
0
);
#else
*
((
bigType
*
)(
&
s
[
k_ot
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
#endif
}
}
__syncthreads
();
...
...
@@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
const
scalar_t
*
B_
=
&
B
[(
m
+
0
)
*
K
+
k_
];
for
(
int
b
=
0
;
b
<
YTILE
;
b
++
)
bigB
[
b
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
b
*
K
])));
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
bigB
[
y
][
k2
].
h8
=
(
loadnt
((
scalar8
*
)(
&
B_
[
min__
(
y
+
m
,
M
-
1
)
*
Kbp
])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
...
...
@@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
#ifdef PCML
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
-
kBase
+
kFit
*
n
])));
#else
if
(
k_
+
K
*
n
<
32
*
1024
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
if
(
k_
+
K
ap
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
ap
*
n
])));
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
ap
*
n
])));
#endif
}
}
...
...
@@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication in interleaved manner
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
constexpr
(
!
use_mfma
)
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
2
;
b
++
)
{
DOT2C
(
sum
[
n
][
y
],
bigA
[
n
][
k2
].
f
[
b
],
bigB
[
y
][
k2
].
f
[
b
])
}
else
#pragma unroll
for
(
uint32_t
b
=
0
;
b
<
A_CHUNK
/
4
;
b
++
)
sum4
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
bigA
[
n
][
k2
].
h4
[
b
],
bigB
[
y
][
k2
].
h4
[
b
],
sum4
[
n
][
y
],
0
,
0
,
0
);
...
...
@@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if
constexpr
(
!
use_mfma
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
sum
[
n
][
y
])
:
"0"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]),
"v"
(
sum
[
n
][
y
]));
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x118
,
0xf
,
0xf
,
1
);
// row_shr8
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x114
,
0xf
,
0xf
,
1
);
// row_shr4
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x112
,
0xf
,
0xf
,
1
);
// row_shr2
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x111
,
0xf
,
0xf
,
1
);
// row_shr1
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
sum
[
n
][
y
]
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
],
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
commitColumn
[
y
])
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__half2float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__half2float
(
biases
[
n
][
y
]);
}
else
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum
[
n
][
i
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
sum
[
n
][
y
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
}
C
[
m
+
i
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
i
]);
C
[
m
+
y
+
n
*
M
]
=
__float2s
<
scalar_t
>
(
sum
[
n
][
y
]);
}
}
}
...
...
@@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm
=
sum4
[
n
][
y
][
0
];
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
1
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
2
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
sum4
[
n
][
y
][
3
]),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
asm
(
"s_nop 0
\n\t
v_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
:
"=v"
(
accm
)
:
"0"
(
accm
),
"v"
(
accm
),
"v"
(
accm
));
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
accm
+=
__builtin_amdgcn_mov_dpp
(
sum4
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x104
,
0xf
,
0xf
,
1
);
// row_shl4
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x108
,
0xf
,
0xf
,
1
);
// row_shl8
accm
=
__builtin_amdgcn_mov_dpp
(
accm
,
0x11f
,
0xf
,
0xf
,
1
);
// row_shr15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x142
,
0xf
,
0xf
,
1
);
// ROW_BCAST15
accm
+=
__builtin_amdgcn_mov_dpp
(
accm
,
0x143
,
0xf
,
0xf
,
1
);
// ROW_BCAST31
sum4
[
n
][
y
][
0
]
=
accm
;
}
}
if
(
threadIdx
.
x
==
63
)
{
scalar_t
biases
[
N
][
YTILE
]
=
{};
if
(
BIAS
)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
biases
[
n
][
y
]
=
BIAS
[(
m
+
y
)
%
Bx
+
(
n
%
By
)
*
Bx
];
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
YTILE
;
i
++
)
{
if
(
commitColumn
[
i
])
{
if
(
BIAS
)
sum4
[
n
][
i
][
0
]
+=
__bfloat162float
(
BIAS
[(
m
+
i
)
%
Bx
+
(
n
%
By
)
*
M
]);
C
[
m
+
i
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
i
][
0
]);
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
commitColumn
[
y
])
{
sum4
[
n
][
y
][
0
]
+=
__bfloat162float
(
biases
[
n
][
y
]);
C
[
m
+
y
+
n
*
M
]
=
__float2bfloat16
(
sum4
[
n
][
y
][
0
]);
}
}
}
...
...
@@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
__global__
void
wvSplitK_hf_big_
(
const
int
K
,
const
int
Kbp
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
...
...
@@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
auto
M_in
=
in_a
.
size
(
0
);
auto
K_in
=
in_a
.
size
(
1
);
auto
N_in
=
in_b
.
size
(
0
);
auto
Kap_in
=
in_a
.
stride
(
0
);
auto
Kbp_in
=
in_b
.
stride
(
0
);
auto
Bx_in
=
(
in_bias
.
has_value
()
&&
in_bias
->
numel
()
>
0
)
?
(
in_bias
->
sizes
().
size
()
==
2
)
?
in_bias
->
size
(
1
)
:
in_bias
->
size
(
0
)
...
...
@@ -1296,27 +1147,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
max_lds_len
=
get_lds_size
()
/
2
;
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len);
\
bool fit_lds = (K
bp
_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
7600642e
...
...
@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [
NKM_FACTORS_WVSPLITK
=
[
# Different batch sizes with key dimensions
(
1
,
16
,
16
),
(
1
,
32
,
16
),
(
1
,
64
,
64
),
(
2
,
256
,
256
),
(
3
,
1024
,
1024
),
(
4
,
4096
,
4096
),
(
4
,
4096
,
4096
+
1
),
(
4
,
4096
+
16
,
4096
),
(
4
,
4096
+
16
,
4096
+
1
),
# Extended K values
(
1
,
9216
,
512
),
(
2
,
10240
,
1024
),
(
4
,
16384
,
8192
),
(
4
,
16384
*
2
,
8192
),
(
4
,
16384
*
2
,
8192
+
1
),
(
4
,
16384
*
2
+
16
,
8192
),
(
4
,
16384
*
2
+
16
,
8192
+
1
),
# Minimum M constraint validation (m >= 8)
(
1
,
64
,
8
),
(
2
,
128
,
8
),
...
...
@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-8
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"xnorm"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
cu_count
=
num_compute_units
()
A
=
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
B
=
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
)
out
=
ops
.
wvSplitK
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
)),
cu_count
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-8
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_bias1D_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
@
pytest
.
mark
.
parametrize
(
"bias_mode"
,
BIAS_MODES
)
@
pytest
.
mark
.
parametrize
(
"padded_a"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"padded_b"
,
[
False
,
True
])
def
test_rocm_wvsplitk_kernel
(
xnorm
,
n
,
k
,
m
,
dtype
,
seed
,
bias_mode
,
padded_a
,
padded_b
):
torch
.
manual_seed
(
seed
)
cu_count
=
num_compute_units
()
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
BIAS
=
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
,
BIAS
)
out
=
ops
.
wvSplitK
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
)),
cu_count
,
BIAS
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-8
,
rtol
=
1e-2
)
xavier
=
(
math
.
sqrt
(
2
/
k
)
if
xnorm
else
1
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_bias2D_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
cu_count
=
num_compute_units
()
BIAS
=
None
if
bias_mode
==
1
:
BIAS
=
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
elif
bias_mode
==
2
:
BIAS
=
torch
.
rand
(
n
,
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
B
IAS
=
torch
.
rand
(
n
,
m
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
if
padded_a
:
A
=
pad_fp8
(
A
)
if
padded_b
:
B
=
pad_fp8
(
B
)
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
,
BIAS
)
out
=
ops
.
wvSplitK
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
)),
cu_count
,
BIAS
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-8
,
rtol
=
1e-2
)
if
xnorm
:
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-8
)
else
:
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"xnorm"
,
[
False
,
True
])
...
...
vllm/model_executor/layers/utils.py
View file @
7600642e
...
...
@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl(
and
on_gfx9
()
and
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
and
k
%
8
==
0
and
x
.
is_contiguous
()
)
if
use_skinny
is
not
True
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment