Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
eab497e8
Commit
eab497e8
authored
Nov 15, 2024
by
letaoqin
Browse files
format
parent
1476d7bb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
20 additions
and
439 deletions
+20
-439
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+5
-5
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+6
-3
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+1
-1
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+0
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
+7
-428
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
eab497e8
...
@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
IndexDataType
>
num_sorted_tiles_host
({
1
});
ck_tile
::
HostTensor
<
IndexDataType
>
num_sorted_tiles_host
({
1
});
#if 0
#if 0
#
if 1
#if 1
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
...
@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
#
else
#else
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
...
@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
#
endif
#endif
// permute weight
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
...
@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
std
::
cout
<<
"------- @@@ "
<<
__LINE__
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
"------- @@@ "
<<
__LINE__
<<
std
::
flush
<<
std
::
endl
;
#
if 0
#if 0
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_ids_host,
topk_weight_host,
topk_weight_host,
...
@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
return 1;
return 1;
#
endif
#endif
#endif
#endif
(
void
)
balance
;
(
void
)
balance
;
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
eab497e8
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
// clang-format on
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
eab497e8
...
@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token(block_m0, block_m1)
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token(block_m0, block_m1)
static
constexpr
ck_tile
::
index_t
BI_
=
static
constexpr
ck_tile
::
index_t
BI_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate (block_n0, block_k1)
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate (block_n0, block_k1)
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden(block_k0)
static
constexpr
ck_tile
::
index_t
BH_
=
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down(block_n1)
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden(block_k0)
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down(block_n1)
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
eab497e8
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
example/ck_tile/16_fused_moe_general/main.cpp
View file @
eab497e8
...
@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
}(
topk_weight_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
}(
topk_weight_host
);
// permute weight
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
View file @
eab497e8
...
@@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl
...
@@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl
}
}
}();
}();
static
constexpr
const
char
*
name
=
"flatmm_
uk
"
;
static
constexpr
const
char
*
name
=
"flatmm_
gl
"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
))
;
return
smem_bridge
;
}
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetOCoord
()
{
constexpr
auto
o_dist
=
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>();
const
auto
o_coord
=
o_dist
.
calculate_index
();
return
o_coord
;
}
CK_TILE_DEVICE
constexpr
auto
GetNumRowCoords_A
()
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
return
MRepeat
;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE
auto
GetRowCoords_A
(
index_t
base_offset
)
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
auto
base_coord
=
threadIdx
.
x
/
KLans
+
base_offset
;
array
<
index_t
,
MRepeat
>
coords
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLans
;
});
return
coords
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetRowID_A
(
const
ROW_COORDS
coords
,
const
IndexDataType
*
sorted_token_ids_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
index_t
,
n_size
>
row_ids
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
row_ids
.
at
(
i
)
=
sorted_token_ids_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
row_ids
;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE
auto
GetRowCoords_O
(
index_t
base_offset
)
{
constexpr
index_t
WarpGemmLane_M
=
16
;
// TODO: use 16x16
constexpr
index_t
WarpGemmRepeat_M
=
BlockShape
::
Block_M0
/
WarpGemmLane_M
;
auto
base_coord
=
threadIdx
.
x
%
WarpGemmLane_M
+
base_offset
;
array
<
index_t
,
WarpGemmRepeat_M
>
coords
;
static_for
<
0
,
WarpGemmRepeat_M
,
1
>
{}(
[
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
WarpGemmLane_M
;
});
return
coords
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetWeightScale
(
const
ROW_COORDS
coords
,
const
TopkWeightDataType
*
sorted_weight_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
TopkWeightDataType
,
n_size
>
w
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
w
.
at
(
i
)
=
sorted_weight_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
w
;
}
CK_TILE_DEVICE
auto
GetRowCoords_O
()
{
constexpr
index_t
NLans
=
BlockShape
::
Block_N1
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
NLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M1
/
MLans
;
auto
base_coord
=
threadIdx
.
x
/
NLans
;
array
<
index_t
,
MRepeat
>
coords
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLans
;
});
return
coords
;
}
/*
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
*/
template
<
typename
Karg
>
template
<
typename
Karg
>
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
sorted_tile_id
,
index_t
sorted_tile_id
,
index_t
intermediate_tile_id
)
index_t
intermediate_tile_id
)
{
{
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
ignore
=
kargs
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
;
ignore
=
smem
;
// w1 (Down, N size)
ignore
=
sorted_tile_id
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
ignore
=
intermediate_tile_id
;
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Warp_K0
;
// divide K in W
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Warp_N1
;
index_t
kr_1
=
shared_intermediate_size_1
/
BlockShape
::
Warp_K1
;
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
// nr*kr*w
index_t
interm_idx_nr
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
// intermediate_tile_id * Block_N / (N in W)
// printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d,
// interm_idx_nr:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id,
// interm_idx_nr);
auto
row_coords_a
=
GetRowCoords_A
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_ids_a
=
GetRowID_A
(
row_coords_a
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
row_ids_a
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
auto
g_win
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr
*
kr_0
*
BlockShape
::
Block_W0
;
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
kAlignmentG
>
{},
number
<
1
>
{});
// number<BlockShape::Block_Nr0>{}.fff();
// number<kAlignmentG>{}.zzz();
auto
g_window_
=
make_tile_window_linear_raw
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
g_window_
;
}();
// number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
auto
g_res
=
g_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
g_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
g_win
.
cached_coords_
[
i
].
get_offset
();
},
number
<
decltype
(
g_win
)
::
NumAccess_NonLinear
>
{});
const
auto
d_win
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_nr
*
BlockShape
::
Block_W1
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_window_
=
make_tile_window_linear_raw
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
d_window_
;
}();
auto
d_res
=
d_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
#if 0
auto d_coords = generate_tuple([&](auto i) {
return d_win.cached_coords_[i].get_offset(); },
number<decltype(d_win)::NumAccess_NonLinear>{});
#else
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto
d_coords
=
[
&
]()
{
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Kr0_
=
4
;
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Kv_
=
8
;
constexpr
index_t
W_
=
Kl_
*
Nl_
*
Kv_
;
constexpr
index_t
num_offsets_
=
Nr_
*
Kr0_
;
index_t
base_os_
=
(
threadIdx
.
x
%
64
)
*
Kv_
+
(
threadIdx
.
x
/
64
)
*
Kr0_
*
Kr1_
*
W_
;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
i_nr_
=
number
<
i
%
Nr_
>
{};
constexpr
auto
i_kr0_
=
number
<
i
/
Nr_
>
{};
return
i_nr_
*
shared_intermediate_size_1
*
Nw_
*
Nl_
+
i_kr0_
*
Kr1_
*
W_
+
base_os_
;
},
number
<
num_offsets_
>
{});
}();
#endif
auto
o_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
row_ids_a
[
i
]
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
kAlignmentO
)
*
kAlignmentO
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
o_flags
=
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
row_ids_a
[
i
],
kargs
.
num_tokens
);
},
number
<
row_ids_a
.
size
()
>
{});
auto
bridge_sst_win
=
[
&
]()
{
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
desc_
.
get_lengths
(),
{
0
,
0
},
dist_
);
}();
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ODataType
));
auto
row_coords_o
=
GetRowCoords_O
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, "
"o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n",
static_cast<int>(blockIdx.x),
static_cast<int>(blockIdx.y),
static_cast<int>(threadIdx.x),
sorted_tile_id,
intermediate_tile_id,
expert_id,
interm_idx_nr,
row_coords_a[0],
row_coords_a[1],
row_coords_a[7],
row_ids_a[0],
row_ids_a[1],
row_ids_a[7],
kr_0 * BlockShape::Block_W0,
g_coords[number<0>{}],
g_coords[number<1>{}],
g_coords[number<7>{}],
o_coords[number<0>{}],
o_coords[number<1>{}],
o_coords[number<2>{}],
o_coords[number<3>{}],
o_coords[number<4>{}],
o_coords[number<5>{}],
o_coords[number<6>{}],
o_coords[number<7>{}],
// (row_ids_a[0] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[1] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[2] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[3] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[4] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[5] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[6] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[7] >= kargs.num_tokens ? 1 : 0)
(row_ids_a[0] < kargs.num_tokens && static_cast<index_t>(o_coords[number<0>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[1] < kargs.num_tokens && static_cast<index_t>(o_coords[number<1>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[2] < kargs.num_tokens && static_cast<index_t>(o_coords[number<2>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[3] < kargs.num_tokens && static_cast<index_t>(o_coords[number<3>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[4] < kargs.num_tokens && static_cast<index_t>(o_coords[number<4>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[5] < kargs.num_tokens && static_cast<index_t>(o_coords[number<5>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[6] < kargs.num_tokens && static_cast<index_t>(o_coords[number<6>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[7] < kargs.num_tokens && static_cast<index_t>(o_coords[number<7>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0)
);
#endif
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0
(
a_res
,
a_coords
,
g_res
,
g_coords
,
smem
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// return ;
//sweep_tile(acc_0,
// [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
,
auto
idx1
)
{
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
acc_0
(
idx0
)
=
v_
.
x
;
acc_0
(
idx1
)
=
v_
.
y
;
},
sequence
<
1
,
2
>
{});
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, bridge_sst_win:%d"
"acc:%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f\n",
static_cast<int>(blockIdx.x),
static_cast<int>(blockIdx.y),
static_cast<int>(threadIdx.x),
sorted_tile_id,
intermediate_tile_id,
expert_id,
interm_idx_nr,
row_coords_a[0],
row_coords_a[1],
row_coords_a[7],
row_ids_a[0],
row_ids_a[1],
row_ids_a[7],
kr_0 * BlockShape::Block_W0,
g_coords[number<0>{}],
g_coords[number<1>{}],
g_coords[number<7>{}],
bridge_sst_win.cached_coords_[number<0>{}].get_offset(),
acc_0.get_thread_buffer()[number<0>{}],
acc_0.get_thread_buffer()[number<1>{}],
acc_0.get_thread_buffer()[number<2>{}],
acc_0.get_thread_buffer()[number<3>{}],
acc_0.get_thread_buffer()[number<4>{}],
acc_0.get_thread_buffer()[number<5>{}],
acc_0.get_thread_buffer()[number<6>{}],
acc_0.get_thread_buffer()[number<7>{}],
acc_0.get_thread_buffer()[number<8 + 0>{}],
acc_0.get_thread_buffer()[number<8 + 1>{}],
acc_0.get_thread_buffer()[number<8 + 2>{}],
acc_0.get_thread_buffer()[number<8 + 3>{}],
acc_0.get_thread_buffer()[number<8 + 4>{}],
acc_0.get_thread_buffer()[number<8 + 5>{}],
acc_0.get_thread_buffer()[number<8 + 6>{}],
acc_0.get_thread_buffer()[number<8 + 7>{}]);
#endif
auto
y_pre
=
cast_tile
<
YDataType
>
(
acc_0
);
store_tile
(
bridge_sst_win
,
y_pre
);
block_sync_lds
();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
d_res
,
d_coords
,
o_res
,
o_coords
,
o_flags
,
smem
,
kargs
.
hidden_size
,
// total n number
w_scale
,
BlockShape
::
Block_Nr1
*
kr_1
*
BlockShape
::
Block_W1
,
// along N
BlockShape
::
Block_N1
);
// along N
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment