Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
747dd16c
Commit
747dd16c
authored
Jan 13, 2025
by
coderfeli
Browse files
result correct but strange
parent
14099622
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
100 additions
and
43 deletions
+100
-43
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+56
-18
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+1
-1
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+1
-1
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
+30
-12
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+11
-10
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
747dd16c
...
...
@@ -11,19 +11,24 @@
template
<
typename
DataType
>
auto
get_elimit
()
{
double
rtol
=
1
e-2
;
double
atol
=
1
e-2
;
double
rtol
=
2
e-2
;
double
atol
=
2
e-2
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
bf16_t
>
()
{
double
rtol
=
1e-
2
;
double
atol
=
1e-
2
;
double
rtol
=
1e-
1
;
double
atol
=
1e-
1
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
typename
T
>
void
fill
(
T
*
x
,
int
len
,
T
val
)
{
for
(
int
i
=
0
;
i
<
len
;
i
++
){
x
[
i
]
=
val
;
}
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
template
<
typename
T
>
...
...
@@ -133,9 +138,9 @@ auto create_args(int argc, char* argv[])
.
insert
(
"tp"
,
"8"
,
"tensor parallel size"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec_i"
,
"
b
f16"
,
"input precision"
)
.
insert
(
"prec_w"
,
"
b
f16"
,
"weight precision"
)
.
insert
(
"prec_o"
,
"
b
f16"
,
"output precision"
)
.
insert
(
"prec_i"
,
"f
p
16"
,
"input precision"
)
.
insert
(
"prec_w"
,
"f
p
16"
,
"weight precision"
)
.
insert
(
"prec_o"
,
"f
p
16"
,
"output precision"
)
.
insert
(
"prec_st"
,
"auto"
,
"token scale data type. auto will set to fp32"
)
.
insert
(
"prec_sw"
,
"auto"
,
"weight scale data type. auto will set to fp32"
)
.
insert
(
"prec_sq"
,
"auto"
,
"(dynamic) smooth quant data type. auto will set to fp32"
)
...
...
@@ -304,14 +309,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
if
(
init
==
3
)
{
ck_tile
::
FillConstant
<
ADataType
>
{}(
a_host
);
ck_tile
::
FillConstant
<
GDataType
>
{}(
g_host
);
ck_tile
::
FillConstant
<
DDataType
>
{}(
d_host
);
ck_tile
::
FillConstant
<
AScaleDataType
>
{}(
sa_host
);
ck_tile
::
FillConstant
<
GScaleDataType
>
{}(
sg_host
);
ck_tile
::
FillConstant
<
DScaleDataType
>
{}(
sd_host
);
ck_tile
::
FillConstant
<
YSmoothScaleDataType
>
{}(
sy_host
);
ck_tile
::
FillConstant
<
TopkWeightDataType
>
{}(
topk_weight_host
);
// ck_tile::FillConstant<ADataType>{}(a_host);
// ck_tile::FillStepRange<ADataType>{0.f, 16384.f, 1.f}(a_host);
// for (int i = 0 ; i < tokens; i++){
// for (int j = 0; j < hidden_size; j++) {
// a_host.mData[i * hidden_size + j] = ck_tile::type_convert<ADataType>(float(i+1) * 0.1 + float(i * j % 116) * 0.0012);
// }
// }
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
topk_weight_host
);
// a_host.savetxt("a.txt");
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(0.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(0.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(0.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
// ck_tile::FillNormalDistribution<ADataType>{.1f, .1f, seed, true}(a_host);
// ck_tile::FillNormalDistribution<GDataType>{.1f, .1f, seed, true}(g_host);
// ck_tile::FillNormalDistribution<DDataType>{.1f, .1f, seed, true}(d_host);
// ck_tile::FillNormalDistribution<AScaleDataType>{1.f, 1.f, seed, true}(sa_host);
// ck_tile::FillNormalDistribution<GScaleDataType>{1.f, 1.f, seed, true}(sg_host);
// ck_tile::FillNormalDistribution<DScaleDataType>{1.f, 1.f, seed, true}(sd_host);
// ck_tile::FillNormalDistribution<YSmoothScaleDataType>{1.f, 1.f, seed, true}(sy_host);
// ck_tile::FillNormalDistribution<TopkWeightDataType>{1.f, 1.f, seed, true}(topk_weight_host);
// ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
// ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
// ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
// ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
}
// permute weight
...
...
@@ -498,6 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
o_dev
.
savetxt
(
"gpu-out.txt"
,
"float"
);
o_host
.
savetxt
(
"ref.txt"
,
"float"
);
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
...
...
@@ -583,7 +621,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Ge
lu
>
(
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Si
lu
>
(
a_host
,
g_host
,
d_host
,
...
...
include/ck_tile/host/fill.hpp
View file @
747dd16c
...
...
@@ -339,7 +339,7 @@ struct FillStepRange
template
<
typename
T
>
struct
FillConstant
{
T
value_
{
type_convert
<
T
>
(
1.
0
f
)};
T
value_
{
type_convert
<
T
>
(
1.
f
)};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
747dd16c
...
...
@@ -157,7 +157,7 @@ void reference_fused_moe(
{
AccDataType
tmp
;
Activation
{}(
tmp
,
acc_0
(
0
,
i_n
));
y
(
0
,
i_n
)
=
tmp
*
acc_0
(
0
,
i_n
+
intermediate_size_1
);
// TODO: elementwise mul
y
(
0
,
i_n
)
=
tmp
+
acc_0
(
0
,
i_n
+
intermediate_size_1
);
// TODO: elementwise mul
}
}
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
View file @
747dd16c
...
...
@@ -201,7 +201,6 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
// [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
// [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
// [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
scale_0
]
"v"
(
s0
),
...
...
@@ -217,7 +216,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"memory"
,
"exec"
,
"m0"
,
"vcc"
,
"scc"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
...
...
@@ -275,14 +274,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x256_
);
#pragma clang diagnostic pop
// clang-format on
if
(
1
)
{
printf
(
"
\n
%d %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f, %.1f
\n
"
,
threadIdx
.
x
,
type_convert
<
float
>
(
v_debug
.
x
),
type_convert
<
float
>
(
v_debug
.
y
),
type_convert
<
float
>
(
v_debug1
.
x
),
type_convert
<
float
>
(
v_debug1
.
y
),
type_convert
<
float
>
(
v_debug2
.
x
),
type_convert
<
float
>
(
v_debug2
.
y
),
type_convert
<
float
>
(
v_debug3
.
x
),
type_convert
<
float
>
(
v_debug3
.
y
));
// if(threadIdx.x==0) {
// printf("%d\n", threadIdx.x);
// }
if
(
threadIdx
.
x
==
0
)
{
printf
(
"%d
\n
"
,
threadIdx
.
x
);
}
// }
// __syncthreads();
}
};
...
...
@@ -356,6 +355,14 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
register
float
v_c29
asm
(
"v93"
);
register
float
v_c30
asm
(
"v94"
);
register
float
v_c31
asm
(
"v95"
);
register
fp16x2_t
v_debug
asm
(
"v160"
);
register
fp16x2_t
v_debug1
asm
(
"v161"
);
register
fp16x2_t
v_debug2
asm
(
"v162"
);
register
fp16x2_t
v_debug3
asm
(
"v163"
);
register
fp16x2_t
v_debug4
asm
(
"v164"
);
register
fp16x2_t
v_debug5
asm
(
"v165"
);
register
fp16x2_t
v_debug6
asm
(
"v166"
);
register
fp16x2_t
v_debug7
asm
(
"v167"
);
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
...
...
@@ -424,7 +431,15 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
[
c28
]
"+v"
(
v_c28
),
[
c29
]
"+v"
(
v_c29
),
[
c30
]
"+v"
(
v_c30
),
[
c31
]
"+v"
(
v_c31
)
[
c31
]
"+v"
(
v_c31
),
[
debug0
]
"+v"
(
v_debug
),
[
debug1
]
"+v"
(
v_debug1
),
[
debug2
]
"+v"
(
v_debug2
),
[
debug3
]
"+v"
(
v_debug3
),
[
debug4
]
"+v"
(
v_debug4
),
[
debug5
]
"+v"
(
v_debug5
),
[
debug6
]
"+v"
(
v_debug6
),
[
debug7
]
"+v"
(
v_debug7
)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
...
...
@@ -471,7 +486,7 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"memory"
,
"exec"
,
"m0"
,
"vcc"
,
"scc"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
...
...
@@ -529,6 +544,9 @@ struct FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x256_
);
#pragma clang diagnostic pop
// clang-format on
if
(
threadIdx
.
x
==
0
)
{
printf
(
"%d
\n
"
,
threadIdx
.
x
);
}
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
747dd16c
...
...
@@ -70,10 +70,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
;
//
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
//
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
//
constexpr index_t smem_bridge =
//
BlockShape::Block_M0 * BlockShape::Block_N0;
return
32768
;
//max(smem_0, max(smem_1, smem_bridge));
}
...
...
@@ -329,7 +329,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// for(auto i = 0; i <
8
; i++)
// for(auto i = 0; i <
16
; i++)
// {
// if(threadIdx.x==0) {
// printf("%d, %.1f, %.1f, %.1f, %.1f\n",i, acc_0_full.get_thread_buffer()[4 * (i) + 0], acc_0_full.get_thread_buffer()[4 * (i) + 1], acc_0_full.get_thread_buffer()[4 * (i) + 2], acc_0_full.get_thread_buffer()[4 * (i) + 3]);
...
...
@@ -366,12 +366,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
}
if
(
!
IsGateOnly
)
{
for
(
auto
i
=
0
;
i
<
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_M0
;
i
++
)
constexpr
auto
REPEATS
=
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_M0
;
for
(
auto
i
=
0
;
i
<
REPEATS
;
i
++
)
{
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
2
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
2
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
3
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
2
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
2
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
3
];
}
}
block_sync_lds
();
...
...
script/cmake-ck-dev.sh
View file @
747dd16c
...
...
@@ -17,7 +17,7 @@ fi
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm/
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
-g -v --save-temps -Wno-gnu-line-marker
"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker "
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment