Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4be253ee
Commit
4be253ee
authored
Jan 15, 2025
by
coderfeli
Browse files
revert back to mul and silu
parent
e15c6f2d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
21 deletions
+45
-21
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+40
-16
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+4
-4
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
4be253ee
...
@@ -23,12 +23,21 @@ auto get_elimit<ck_tile::bf16_t>()
...
@@ -23,12 +23,21 @@ auto get_elimit<ck_tile::bf16_t>()
double
atol
=
1e-1
;
double
atol
=
1e-1
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
typename
T
>
// template<typename T>
void
fill
(
T
*
x
,
int
len
,
T
val
)
{
// void cleartail(T * x, int len) {
for
(
int
i
=
0
;
i
<
len
;
i
++
){
// int len_32b = len * sizeof(T) / 4;
x
[
i
]
=
val
;
// uint32_t *x_u32 = reinterpret_cast<uint32_t *>(x);
}
// for(int i = 0; i <len_32b; i++){
}
// x_u32[i] = x_u32[i] & 0xfff0fff0;
// }
// }
// template<typename T>
// void fill(T * x, int len, T val) {
// for(int i = 0; i <len; i++){
// x[i] = val;
// }
// }
// mfma_type, 0:32x32, 1:16x16
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
// TODO: padding?
template
<
typename
T
>
template
<
typename
T
>
...
@@ -309,15 +318,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -309,15 +318,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
if
(
init
==
3
)
else
if
(
init
==
3
)
{
{
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(0.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(0.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(0.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
ck_tile
::
FillNormalDistribution
<
ADataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillNormalDistribution
<
ADataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillNormalDistribution
<
GDataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillNormalDistribution
<
GDataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillNormalDistribution
<
DDataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
d_host
);
ck_tile
::
FillNormalDistribution
<
DDataType
>
{
0.
f
,
.1
f
,
seed
,
true
}(
d_host
);
...
@@ -326,6 +326,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -326,6 +326,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillNormalDistribution
<
DScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillNormalDistribution
<
DScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
YSmoothScaleDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
ck_tile
::
FillNormalDistribution
<
TopkWeightDataType
>
{
0.
f
,
1.
f
,
seed
,
true
}(
topk_weight_host
);
// cleartail((ADataType *)a_host.mData.data(), a_host.size());
// cleartail((GDataType *)g_host.mData.data(), g_host.size());
// cleartail((DDataType *)d_host.mData.data(), d_host.size());
// a_host.savetxt("a.txt");
// cleartail((AScaleDataType *)sa_host.mData.data(), sa_host.size());
// cleartail((GScaleDataType *)sg_host.mData.data(), sg_host.size());
// cleartail((DScaleDataType *)sd_host.mData.data(), sd_host.size());
// cleartail((DScaleDataType *)sd_host.mData.data(), sd_host.size());
// cleartail((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size());
// fill((ADataType *)a_host.mData.data(), a_host.size(), ck_tile::type_convert<ADataType>(.1f));
// fill((GDataType *)g_host.mData.data(), g_host.size(), ck_tile::type_convert<GDataType>(.1f));
// fill((DDataType *)d_host.mData.data(), d_host.size(), ck_tile::type_convert<DDataType>(.1f));
// fill((AScaleDataType *)sa_host.mData.data(), sa_host.size(), ck_tile::type_convert<AScaleDataType>(1.f));
// fill((GScaleDataType *)sg_host.mData.data(), sg_host.size(), ck_tile::type_convert<GScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((DScaleDataType *)sd_host.mData.data(), sd_host.size(), ck_tile::type_convert<DScaleDataType>(1.f));
// fill((YSmoothScaleDataType *)sy_host.mData.data(), sy_host.size(), ck_tile::type_convert<YSmoothScaleDataType>(1.f));
// fill((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size(), ck_tile::type_convert<TopkWeightDataType>(1.f));
// cleartail((TopkWeightDataType *)topk_weight_host.mData.data(), topk_weight_host.size());
}
}
// permute weight
// permute weight
...
@@ -484,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -484,7 +508,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts
,
experts
,
block_m
);
block_m
);
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Ge
lu
>
(
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Si
lu
>
(
a_host
,
a_host
,
g_host
,
g_host
,
d_host
,
d_host
,
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
4be253ee
...
@@ -157,7 +157,7 @@ void reference_fused_moe(
...
@@ -157,7 +157,7 @@ void reference_fused_moe(
{
{
AccDataType
tmp
;
AccDataType
tmp
;
Activation
{}(
tmp
,
acc_0
(
0
,
i_n
));
Activation
{}(
tmp
,
acc_0
(
0
,
i_n
));
y
(
0
,
i_n
)
=
tmp
+
acc_0
(
0
,
i_n
+
intermediate_size_1
);
// TODO: elementwise mul
y
(
0
,
i_n
)
=
tmp
*
acc_0
(
0
,
i_n
+
intermediate_size_1
);
// TODO: elementwise mul
}
}
}
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
4be253ee
...
@@ -380,10 +380,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -380,10 +380,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr
auto
REPEATS
=
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_M0
;
constexpr
auto
REPEATS
=
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_M0
;
for
(
auto
i
=
0
;
i
<
REPEATS
;
i
++
)
for
(
auto
i
=
0
;
i
<
REPEATS
;
i
++
)
{
{
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
2
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
2
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
2
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
2
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
+
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
3
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
*
=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
REPEATS
)
+
3
];
}
}
}
}
block_sync_lds
();
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment