Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
518551b1
Commit
518551b1
authored
Feb 12, 2025
by
mtgu0705
Browse files
fix.
parent
809a0c97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
33 deletions
+2
-33
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
+2
-33
No files found.
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
View file @
518551b1
...
@@ -20,37 +20,6 @@ using ALayout = Row;
...
@@ -20,37 +20,6 @@ using ALayout = Row;
using
BLayout
=
Col
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
/// Copies an N x K buffer (read as src[n * K + k]) into dst using the tiled
/// index order  N, K -> N0 K0 KLane NLane KPack, where
///   K -> K0 KLane KPack   and   N -> N0 NLane.
///
/// @param src   Source buffer, indexed as src[n * K + k].
/// @param dst   Destination buffer, written at the shuffled index.
/// @param N     Number of rows iterated (outer loop bound).
/// @param K     Number of columns iterated (inner loop bound).
/// @param NXdl  Used directly as NLane; KLane is derived as 64 / NXdl.
///
/// Preconditions implied by the arithmetic below (not checked here):
///   - NXdl must be in [1, 64]; otherwise NLane or KLane is zero and the
///     divisions below are undefined.
///   - K should be a multiple of KLane * KPack; otherwise K0 is truncated and
///     indices from different n0 tiles overlap.
///   - N should be a multiple of NXdl so the largest computed index stays
///     below N * K.
/// NOTE(review): dst is presumably sized for at least N * K elements and must
/// not alias src -- confirm against the caller.
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
{
    // All index math is done in 64-bit so that n * K + k and the tiled
    // destination offset cannot overflow a 32-bit int on large problems.
    const long long KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
    const long long NLane = NXdl;
    const long long KLane = 64 / NLane;

    // Number of K elements covered by one (KLane x KPack) tile, and the
    // number of such tiles along K.
    const long long KTile = KLane * KPack;
    const long long K0    = K / KTile;

    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
    for(int n = 0; n < N; ++n)
    {
        // Row-dependent terms are invariant over k, so compute them once.
        const long long n0 = n / NLane;
        const long long n1 = n % NLane;

        const long long rowOffset = n0 * KPack * NLane * KLane * K0 + n1 * KPack;
        const long long srcRow    = static_cast<long long>(n) * K;

        for(int k = 0; k < K; ++k)
        {
            const long long k0    = k / KTile;
            const long long tempk = k % KTile; // position inside the K tile
            const long long k1    = tempk / KPack;
            const long long k2    = tempk % KPack;

            const long long outputIndex =
                rowOffset + k0 * KPack * NLane * KLane + k1 * KPack * NLane + k2;

            dst[outputIndex] = src[srcRow + k];
        }
    }
}
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
...
@@ -180,9 +149,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -180,9 +149,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
// N -> N0 NLane
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
// N, K -> N0 K0 KLane NLane KPack
int
tempk
;
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
int
n0
=
n
/
NLane
;
int
n0
=
n
/
NLane
;
int
n1
=
n
%
NLane
;
int
n1
=
n
%
NLane
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment