gaoqiong / composable_kernel_ROCM · Commits

Commit 7ccdbe16
Authored Nov 13, 2024 by carlushuang

    update

Parent: e2a318bc

Showing 10 changed files with 560 additions and 200 deletions (+560, -200)
    example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp                (+1, -1)
    example/ck_tile/15_fused_moe/main.cpp                                        (+93, -6)
    include/ck_tile/core/arch/utility.hpp                                        (+24, -0)
    include/ck_tile/host/fill.hpp                                                (+38, -0)
    include/ck_tile/host/host_tensor.hpp                                         (+82, -2)
    include/ck_tile/host/reference/reference_fused_moe.hpp                       (+29, -20)
    include/ck_tile/ops/flatmm/pipeline/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.hpp  (+72, -61)
    include/ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp     (+51, -56)
    include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp                (+25, -7)
    include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp  (+145, -47)
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp

@@ -20,7 +20,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
        t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 &&
        t.gate_only == 1)
     {
         using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float,
                          float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
-        fused_moegemm_<t_>(s, a);
+        r = fused_moegemm_<t_>(s, a);
     }
     // clang-format on
     return r;
example/ck_tile/15_fused_moe/main.cpp

@@ -121,8 +121,6 @@ auto create_args(int argc, char* argv[])
 template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
-    std::cout << "xxxx " << __LINE__ << std::flush << std::endl;
-
     ck_tile::index_t tokens  = arg_parser.get_int("t");
     ck_tile::index_t experts = arg_parser.get_int("e");
     ck_tile::index_t topk    = arg_parser.get_int("k");
@@ -173,7 +171,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << prec_str << "]"
               << " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
               << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
-              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
+              << ", shared_interm:" << shared_intermediate_size_0 << "|"
+              << shared_intermediate_size_1
+              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;

     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
     using ADataType  = typename TypeConfig::ADataType;
@@ -191,7 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host verify
     ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
     ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
-    ck_tile::HostTensor<DDataType> d_host({experts, shared_intermediate_size_1, hidden_size});
+    ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
     ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
@@ -207,6 +207,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {(max_num_tokens_padded + block_m - 1) / block_m});
     ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});

+#if 1
+#if 1
+    ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
+    ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
+    ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
+    ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
+    ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
+    ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
+    ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
+    ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
+#else
     ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
     ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
     ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
@@ -215,6 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
     ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
     ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
+#endif

     // permute weight
     ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
@@ -236,6 +248,77 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
     }
+#else
+    a_host.loadtxt("../../ater/input_torch.txt");
+    topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
+    // topk_ids_host.savetxt("topk_ids_2.txt");
+    topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+    g_host.loadtxt("../../ater/w1_torch.txt", "float");
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+    d_host.loadtxt("../../ater/w2_torch.txt", "float");
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+
+    ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+    ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+
+    ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(topk_ids_host,
+                                                                      topk_weight_host,
+                                                                      sorted_token_ids_host,
+                                                                      sorted_weight_host,
+                                                                      sorted_expert_ids_host,
+                                                                      num_sorted_tiles_host.mData[0],
+                                                                      experts,
+                                                                      block_m);
+    std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
+    std::cout << sorted_token_ids_host << std::endl;
+    std::cout << num_sorted_tiles_host << std::endl;
+    std::cout << sorted_expert_ids_host << std::endl;
+
+    ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(a_host,
+                                                                           g_host,
+                                                                           d_host,
+                                                                           sa_host,
+                                                                           sg_host,
+                                                                           sd_host,
+                                                                           sy_host,
+                                                                           o_host,
+                                                                           sorted_token_ids_host,
+                                                                           sorted_weight_host,
+                                                                           sorted_expert_ids_host,
+                                                                           num_sorted_tiles_host,
+                                                                           topk_ids_host,
+                                                                           block_m,
+                                                                           tokens,
+                                                                           experts,
+                                                                           hidden_size,
+                                                                           shared_intermediate_size_0,
+                                                                           topk,
+                                                                           gate_only);
+    std::cout << "------- >" << std::endl;
+    std::cout << o_host << std::endl;
+    (void)balance;
+    {
+        ck_tile::HostTensor<ODataType> o_host_torch({tokens, hidden_size}, {stride, 1});
+        o_host_torch.loadtxt("../../ater/ref2_torch.txt");
+        auto [rtol, atol] = get_elimit<ADataType>();
+        bool pass = ck_tile::check_err(
+            o_host, o_host_torch, std::string("OUT-Torch Error: Incorrect results!"), rtol, atol);
+        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
+    }
+    return 1;
+#endif
     ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(topk_ids_host,
@@ -247,8 +330,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         experts,
         block_m);
-    // std::cout << sorted_token_ids_host << std::endl;
-    // std::cout << num_sorted_tiles_host << std::endl;
+    std::cout << sorted_token_ids_host << std::endl;
+    std::cout << num_sorted_tiles_host << std::endl;
+    std::cout << sorted_expert_ids_host << std::endl;
+    std::cout << topk_weight_host << std::endl;
+    std::cout << sorted_weight_host << std::endl;

     // done, preparing GPU buffer
     ck_tile::DeviceMem a_buf(a_host);
include/ck_tile/core/arch/utility.hpp

@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
 #endif
 }

+template <typename T>
+CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
+{
+    static_assert(sizeof(T) == 4);
+    // per-thread v_flag store into 2x sgpr
+    uint32x2_t exec_flag;
+    asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
+                 : [s_exec_flag] "=s"(exec_flag)
+                 : [v_flag] "v"(v_flag));
+    return exec_flag;
+}
+
+template <typename X, typename Y>
+CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
+{
+    static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
+    // per-thread cmp store into 2x sgpr
+    uint32x2_t exec_flag;
+    asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
+                 : [s_exec_flag] "=s"(exec_flag)
+                 : [v_x] "v"(x), [v_y] "v"(y));
+    return exec_flag;
+}
+
 } // namespace ck_tile
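Both helpers return a 64-bit lane mask split across a pair of SGPRs, which is exactly the shape the flatmm micro-kernels below move into `exec` before their per-lane atomic stores. A minimal consumption sketch (not part of the commit; the function name and `saved_exec` argument are illustrative):

```cpp
// Sketch only: predicate a store region on row_id < num_tokens, then restore exec.
// Assumes index_t is a 4-byte type (required by cmp_lt_to_exec's static_assert).
CK_TILE_DEVICE void predicated_store_sketch(index_t row_id,               // per-lane (VGPR) value
                                            index_t num_tokens,           // uniform bound
                                            const uint32x2_t& saved_exec) // full-wave exec saved earlier
{
    auto exec_flag = cmp_lt_to_exec(row_id, num_tokens); // one bit per in-range lane, in 2x SGPR
    asm volatile(" s_mov_b64 exec, %[s_flag]\n"  // mask off out-of-range lanes
                 " ; lane-masked stores go here (e.g. global_atomic_pk_add_bf16)\n"
                 " s_mov_b64 exec, %[s_saved]\n" // restore the full wave
                 :
                 : [s_flag] "s"(exec_flag), [s_saved] "s"(saved_exec)
                 : "memory");
}
```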
include/ck_tile/host/fill.hpp

@@ -235,6 +235,44 @@ struct FillMonotonicSeq
     }
 };

+template <typename T, bool IsAscending = true>
+struct FillStepRange
+{
+    float start_value_{0};
+    float end_value_{3};
+    float step_{1};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::generate(first, last, [=, n = start_value_]() mutable {
+            auto tmp = n;
+            n += step_;
+            if constexpr(IsAscending)
+            {
+                if(n > end_value_)
+                    n = start_value_;
+            }
+            else
+            {
+                if(n < end_value_)
+                    n = start_value_;
+            }
+            return type_convert<T>(tmp);
+        });
+    }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const -> std::void_t<decltype(
+        std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
+                                             std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
+};
+
 template <typename T>
 struct FillConstant
 {
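`FillStepRange` emits a deterministic ramp that wraps when it passes `end_value_`, so a buffer is reproducible and easy to eyeball against a dumped reference, unlike `FillUniformDistribution`. A standalone usage sketch (assumption: a plain `std::vector` works because the range overload only needs `std::begin`/`std::end`):

```cpp
#include <cstdio>
#include <vector>

void fill_step_range_demo()
{
    std::vector<float> v(8);
    ck_tile::FillStepRange<float>{-.5f, .5f, .25f}(v);
    // ascending, wraps past the end: -0.5 -0.25 0 0.25 0.5 -0.5 -0.25 0
    for(float x : v)
        std::printf("%g ", x);
    std::printf("\n");

    std::vector<float> w(4);
    ck_tile::FillStepRange<float, false>{.5f, -.5f, -.25f}(w); // descending variant
    // 0.5 0.25 0 -0.25
}
```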
include/ck_tile/host/host_tensor.hpp

@@ -12,6 +12,7 @@
 #include <utility>
 #include <vector>
 #include <functional>
+#include <fstream>

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/ranges.hpp"
@@ -589,7 +590,7 @@ struct HostTensor
         return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
                                       size() * FromSize / ToSize};
     }

 #if 1
     friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
     {
         os << t.mDesc;
@@ -600,11 +601,90 @@ struct HostTensor
             {
                 os << ", ";
             }
-            os << t.mData[idx];
+            if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
+            {
+                os << type_convert<float>(t.mData[idx]) << " #### ";
+            }
+            else
+            {
+                os << t.mData[idx];
+            }
         }
         os << "]";
         return os;
     }
 #endif

+    // read data from a file, as dtype
+    // the file could be dumped from torch as (the target tensor is t here):
+    //     numpy.savetxt("f.txt", t.view(-1).numpy())
+    //     numpy.savetxt("f.txt", t.cpu().view(-1).numpy())            # from cuda to cpu to save
+    //     numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d")  # save as int
+    // this outputs f.txt with one value per line
+    // dtype = "float" or "int"; internally values are cast to the real element type
+    void loadtxt(std::string file_name, std::string dtype = "float")
+    {
+        std::ifstream file(file_name);
+        if(file.is_open())
+        {
+            std::string line;
+            index_t cnt = 0;
+            while(std::getline(file, line))
+            {
+                if(cnt >= static_cast<index_t>(mData.size()))
+                {
+                    throw std::runtime_error(std::string("data read from file:") + file_name +
+                                             " is too big");
+                }
+                if(dtype == "float")
+                {
+                    mData[cnt] = type_convert<T>(std::stof(line));
+                }
+                else if(dtype == "int" || dtype == "int32")
+                {
+                    mData[cnt] = type_convert<T>(std::stoi(line));
+                }
+                cnt++;
+            }
+            file.close();
+            if(cnt < static_cast<index_t>(mData.size()))
+            {
+                std::cerr << "Warning! reading from file:" << file_name
+                          << ", does not match the size of this tensor" << std::endl;
+            }
+        }
+        else
+        {
+            // Print an error message to the standard error
+            // stream if the file cannot be opened.
+            throw std::runtime_error(std::string("unable to open file:") + file_name);
+        }
+    }
+
+    // can save to a txt file and read from torch as:
+    //     torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
+    void savetxt(std::string file_name)
+    {
+        std::ofstream file(file_name);
+        if(file.is_open())
+        {
+            for(auto& itm : mData)
+            {
+                file << itm << std::endl;
+            }
+            file.close();
+        }
+        else
+        {
+            // Print an error message to the standard error
+            // stream if the file cannot be opened.
+            throw std::runtime_error(std::string("unable to open file:") + file_name);
+        }
+    }
+
     Descriptor mDesc;
     Data mData;
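The comments above give the torch/numpy side of the exchange; here is a round-trip sketch on the C++ side (illustrative, not from the commit — the file names and shape are placeholders). Note that `loadtxt` throws when the file holds more values than the tensor and only warns when it holds fewer:

```cpp
// Python side first (dump a reference tensor, one value per line):
//   numpy.savetxt("a.txt", t.cpu().view(-1).numpy())
void host_tensor_txt_roundtrip()
{
    ck_tile::HostTensor<float> a({4, 16}); // expects 64 lines in a.txt
    a.loadtxt("a.txt", "float");
    a.savetxt("a_copy.txt");
    // readable back in torch:
    //   torch.from_numpy(np.loadtxt('a_copy.txt', dtype=np.float32)).view([4, 16]).contiguous()
}
```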
include/ck_tile/host/reference/reference_fused_moe.hpp

@@ -53,12 +53,12 @@ template <typename AccDataType, // you only need to explcitly set this one
           typename IndexDataType>
 void reference_fused_moe(
     const ck_tile::HostTensor<ADataType>& a_host,       // [tokens, hidden_size]
-    const ck_tile::HostTensor<GDataType>& g_host,       // [experts, interme_size, hidden_size]
-    const ck_tile::HostTensor<DDataType>& d_host,       // [experts, hidden_size, hidden_size]
+    const ck_tile::HostTensor<GDataType>& g_host,       // [experts, interme_size_0, hidden_size]
+    const ck_tile::HostTensor<DDataType>& d_host,       // [experts, hidden_size, interme_size_1]
     const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
-    const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size]
+    const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
     const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
-    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size]
+    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
     ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
     const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
     const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
@@ -73,7 +73,7 @@ void reference_fused_moe(
     ck_tile::index_t tokens,
     ck_tile::index_t experts,
     ck_tile::index_t hidden_size,
-    ck_tile::index_t intermediate_size,
+    ck_tile::index_t intermediate_size, // this size is for gate/up
     ck_tile::index_t topk,
     ck_tile::index_t gate_only)
 {
@@ -81,7 +81,9 @@ void reference_fused_moe(
     assert(sorted_weight_host.get_num_of_dimension() == 1);
     assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
     assert(num_sorted_tiles_host.get_element_size() == 1);
-    ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0];
+    ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
+    ck_tile::index_t intermediate_size_0 = intermediate_size;
+    ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);

     // TODO: better remove this in the future, or modify the token_id value
     auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
@@ -90,6 +92,7 @@ void reference_fused_moe(
             if(token_ids_host(token_id_, i_) == expert_id_)
                 return i_;
         }
         throw std::runtime_error("not correct token/expert pair\n");
+        return -1; // TODO: not correct!!
     };
@@ -108,9 +111,9 @@ void reference_fused_moe(
         ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
         auto weight = sorted_weight_host.mData[i_flatten];
-        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size});
+        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
         // first gemm
-        for(ck_tile::index_t i_n = 0; i_n < intermediate_size; i_n++)
+        for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
         {
             AccDataType acc = static_cast<AccDataType>(0);
             for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
@@ -121,32 +124,38 @@ void reference_fused_moe(
             acc_0(0, i_n) = acc;
         }

-        ck_tile::HostTensor<AccDataType> y({1, hidden_size});
+        ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
         if(gate_only)
         {
-            assert(hidden_size == intermediate_size);
-            for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
+            if(intermediate_size_1 != intermediate_size_0)
+                throw std::runtime_error("intermediate_size not correct, 0:" +
+                                         std::to_string(intermediate_size_0) + ", 1:" +
+                                         std::to_string(intermediate_size_1));
+            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
             {
                 Activation{}(y(0, i_n), acc_0(0, i_n));
             }
         }
         else
         {
-            assert(hidden_size * 2 == intermediate_size);
-            for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
+            if(intermediate_size_1 * 2 != intermediate_size_0)
+                throw std::runtime_error("intermediate_size not correct, 0:" +
+                                         std::to_string(intermediate_size_0) + ", 1:" +
+                                         std::to_string(intermediate_size_1));
+            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
             {
                 AccDataType tmp;
                 Activation{}(tmp, acc_0(0, i_n));
-                y(0, i_n) = tmp * acc_0(0, i_n + hidden_size); // TODO: elementwise mul
+                y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
             }
         }

-        // second gemm
+        // second gemm, loop along gemm-n
         ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
         for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
         {
             AccDataType acc = static_cast<AccDataType>(0);
-            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
+            for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
             {
                 acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
             }
@@ -165,12 +174,12 @@ void reference_fused_moe(
     auto r = [&](auto i_token) {
         for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
         {
-            ODataType acc = type_convert<ODataType>(0);
+            AccDataType acc = type_convert<ODataType>(0);
             for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
             {
                 acc += out_topk_tokens(i_token, i_topk, i_n);
             }
-            o_host(i_token, i_n) = acc;
+            o_host(i_token, i_n) = type_convert<ODataType>(acc);
         }
     };
     make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
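The renamed sizes encode the gate/up split: `intermediate_size_0` is the first gemm's N extent, and `intermediate_size_1` is the activation width and the second gemm's K extent. A small helper restating that arithmetic (illustrative, not part of the commit):

```cpp
struct MoeIntermediateSizes
{
    ck_tile::index_t size_0; // first-gemm N (gate, or gate+up packed along N)
    ck_tile::index_t size_1; // activation width == second-gemm K
};

inline MoeIntermediateSizes derive_moe_sizes(ck_tile::index_t intermediate_size, bool gate_only)
{
    // gate-only: size_1 == size_0; gate+up: w1 packs gate and up along N, so size_1 is half
    return {intermediate_size, intermediate_size / (gate_only ? 1 : 2)};
}
// e.g. derive_moe_sizes(1024, /*gate_only=*/false) -> {1024, 512}; the reference then computes
// y(0, i) = Act(acc_0(0, i)) * acc_0(0, i + 512) for i in [0, 512).
```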
include/ck_tile/ops/flatmm/pipeline/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.hpp

@@ -41,6 +41,17 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
     using BDataType = bf16_t;
     using ODataType = bf16_t;

+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        //                  y     y     p     p     p      y
+        // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
+        // but order is N0*M0*Nv
+        // in LDS we need store as
+        //    M0(2) * N0(2) * Nl(4)   * Nw(4)  * (Mw(16)*Nv(4) + 4)
+        //    y       y       wave-id   lid/16   lid%16  v
+        return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
+    }
+
     // TODO: need paired with tile_window_linear!
     // TODO: need call init_raw() before call this function!
     // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
@@ -48,30 +59,26 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
     template <typename BRes,
               typename BCoords,
               typename ORes,
               typename OCoords,
+              typename OFlags,
               typename ScaleTensor>
     CK_TILE_DEVICE auto operator()(const BRes& res_b,
                                    const BCoords& cached_coords_b,
                                    const ORes& res_o,
                                    const OCoords& cached_coords_o,
+                                   const OFlags& o_flags, // this should be in sgpr
                                    CK_TILE_LDS_ADDR void* smem,
                                    // OWindow& o_window_,
                                    index_t n, // loop along n dim
                                    const ScaleTensor& scale_,
-                                   index_t stride_b, // stride b is fixed to blockKr * blockW, but still can adjust
-                                   index_t stride_o)
+                                   index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
+                                   index_t tile_offset_o)
     {
         // auto cached_coords_b = b_window_.cached_coords_;
         // auto res_b =
        //     b_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto
        //     cached_coords_o = o_window_.cached_coords_; auto res_o =
        //     o_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
         static_assert(BCoords::size() == 8); // 8
         static_assert(OCoords::size() == 8);

-        const index_t stride_b_bytes = stride_b * sizeof(BDataType);
-        const index_t stride_o_bytes = stride_o * sizeof(ODataType);
+        const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
+        const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);

         static_assert(ScaleTensor::size() == 2);
         float s0 = scale_[number<0>{}];
@@ -143,6 +150,7 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
         asm volatile(
             ";-------------------------------------------------------------\n"
             " s_mov_b32 s52, 0x07060302                ; v_perm\n"
+            " s_mov_b64 s[38:39], exec                 ; save current exec\n"
             " s_mov_b32 s8, %[s_res_o0]\n"
             " s_mov_b32 s9, %[s_res_o1]\n"
             " s_mov_b32 s12, %[s_res_b0]\n"
@@ -247,10 +255,9 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
             " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048\n"
             " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 1             ; move b with cond\n"
-            " s_cselect_b32 s86, %[s_stride_b], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_b], 0\n"
             " s_add_u32 s12, s86, s12\n"
             " s_addc_u32 s13, 0, s13\n"
-            " s_waitcnt vmcnt(24)\n"
             "L_start%=:\n"
             " s_waitcnt vmcnt(32)\n"
             " s_barrier\n"
@@ -517,39 +524,37 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
             " ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base]\n"
             " ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base]\n"
             " s_waitcnt lgkmcnt(0)\n"
-            //" s_mov_b64 exec, s[16:17] \n"
-            // "s_endpgm\n"
+            " s_mov_b64 exec, %[s_execflag_0]\n"
             " global_atomic_pk_add_bf16 %[v_os_o0], %[c0], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[18:19] \n"
+            " s_mov_b64 exec, %[s_execflag_1]\n"
             " global_atomic_pk_add_bf16 %[v_os_o1], %[c1], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[20:21] \n"
+            " s_mov_b64 exec, %[s_execflag_2]\n"
             " global_atomic_pk_add_bf16 %[v_os_o2], %[c2], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[22:23] \n"
+            " s_mov_b64 exec, %[s_execflag_3]\n"
             " global_atomic_pk_add_bf16 %[v_os_o3], %[c3], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[24:25] \n"
+            " s_mov_b64 exec, %[s_execflag_4]\n"
             " global_atomic_pk_add_bf16 %[v_os_o4], %[c4], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[26:27] \n"
+            " s_mov_b64 exec, %[s_execflag_5]\n"
             " global_atomic_pk_add_bf16 %[v_os_o5], %[c5], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[28:29] \n"
-            //"s_endpgm \n"
+            " s_mov_b64 exec, %[s_execflag_6]\n"
             " global_atomic_pk_add_bf16 %[v_os_o6], %[c6], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[30:31] \n"
+            " s_mov_b64 exec, %[s_execflag_7]\n"
             " global_atomic_pk_add_bf16 %[v_os_o7], %[c7], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
+            " s_mov_b64 exec, s[38:39]\n"
             " s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k--\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 0\n"
             " s_cbranch_scc0 L_end%=\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 1             ; move b with cond\n"
-            " s_cselect_b32 s86, %[s_stride_b], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_b], 0\n"
             " s_add_u32 s12, s86, s12\n"
             " s_addc_u32 s13, 0, s13\n"
-            " s_add_u32 s8, %[s_stride_o], s8\n"
+            " s_add_u32 s8, %[s_tile_os_o], s8\n"
             " s_addc_u32 s9, 0, s9\n"
             //" s_addk_i32 s80, 0x0080 \n"
             //" s_cmp_lt_i32 s80, s81 \n"
             //" s_cbranch_scc0 label_0E98 \n"
@@ -817,38 +822,31 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
             " ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base]\n"
             " ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base]\n"
             " s_waitcnt lgkmcnt(0)\n"
-            //" s_mov_b64 exec, s[16:17] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o0], %[c16], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[18:19] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o1], %[c17], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[20:21] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o2], %[c18], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[22:23] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o3], %[c19], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[24:25] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o4], %[c20], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[26:27] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o5], %[c21], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[28:29] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o6], %[c22], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
-            //" s_mov_b64 exec, s[30:31] \n"
-            " global_atomic_pk_add_bf16 %[v_os_o7], %[c23], s[8:9]\n"
-            //" s_mov_b64 exec, s[36:37] \n"
+            " s_mov_b64 exec, %[s_execflag_0]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o0], %[c0], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_1]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o1], %[c1], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_2]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o2], %[c2], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_3]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o3], %[c3], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_4]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o4], %[c4], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_5]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o5], %[c5], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_6]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o6], %[c6], s[8:9]\n"
+            " s_mov_b64 exec, %[s_execflag_7]\n"
+            " global_atomic_pk_add_bf16 %[v_os_o7], %[c7], s[8:9]\n"
+            " s_mov_b64 exec, s[38:39]\n"
             " s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k--\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 0\n"
             " s_cbranch_scc0 L_end%=\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 1             ; move b with cond\n"
-            " s_cselect_b32 s86, %[s_stride_b], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_b], 0\n"
             " s_add_u32 s12, s86, s12\n"
             " s_addc_u32 s13, 0, s13\n"
-            " s_add_u32 s8, %[s_stride_o], s8\n"
+            " s_add_u32 s8, %[s_tile_os_o], s8\n"
             " s_addc_u32 s9, 0, s9\n"
             " s_branch L_start%=\n"
             "L_end%=:\n"
@@ -917,13 +915,22 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
             [v_os_b6] "v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
             [v_os_b7] "v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
-            [s_stride_o] "s"(stride_o_bytes),
-            [s_stride_b] "s"(stride_b_bytes),
+            [s_tile_os_o] "s"(tile_stride_o_bytes),
+            [s_tile_os_b] "s"(tile_stride_b_bytes),
             [scale_0] "v"(s0),
             [scale_1] "v"(s1),
             [v_nan_lo] "v"(nan_lo),
-            [v_nan_hi] "v"(nan_hi)
-            : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
+            [v_nan_hi] "v"(nan_hi),
+            [s_execflag_0] "s"(o_flags[number<0>{}]),
+            [s_execflag_1] "s"(o_flags[number<1>{}]),
+            [s_execflag_2] "s"(o_flags[number<2>{}]),
+            [s_execflag_3] "s"(o_flags[number<3>{}]),
+            [s_execflag_4] "s"(o_flags[number<4>{}]),
+            [s_execflag_5] "s"(o_flags[number<5>{}]),
+            [s_execflag_6] "s"(o_flags[number<6>{}]),
+            [s_execflag_7] "s"(o_flags[number<7>{}])
+            : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
              "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
              "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
             "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
@@ -953,9 +960,13 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
             "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
             "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
             "a252", "a253", "a254", "a255",
-            "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
-            "s52", "s86",
+            "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
             // "s32", "s33",
             "v50", "v54", "v55",
             "v64", "v65", "v66", "v67", "v68", "v69", "v70", "v71",
             "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
             "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87",
             "v88", "v89", "v90", "v91", "v92", "v93", "v94", "v95",
             "v128", "v129", "v130", "v131", "v132", "v133", "v134", "v135",
             "v136", "v137", "v138", "v139", "v140", "v141", "v142", "v143",
             "v144", "v145", "v146", "v147",
include/ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp

@@ -241,17 +241,13 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
         return enc_{};
     }

+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return 32 * (128 + 8) * sizeof(bf16_t);
+    }
+
     // TODO: need paired with tile_window_linear!
     // TODO: need call init_raw() before call this function!
-#if 0
-    template <typename AWindow, typename BWindow, typename SmemWindow>
-    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
-                                   const BWindow& b_window_,
-                                   SmemWindow& smem_window_,
-                                   index_t k,
-                                   index_t stride_a,
-                                   index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
-#else
     template <typename ARes, typename ACoords, typename BRes, typename BCoords>
     CK_TILE_DEVICE auto operator()(const ARes& res_a,
@@ -260,9 +256,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
                                    const BCoords& cached_coords_b,
                                    CK_TILE_LDS_ADDR void* smem,
                                    index_t k,
-                                   index_t stride_a,
-                                   index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
-#endif
+                                   index_t tile_offset_a, // for each tile, the offset to move for each unroll
+                                   index_t tile_offset_b) // for each tile, the offset to move for each unroll
     {
         static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
         static_assert(BCoords::size() == Repeat_N);
@@ -292,8 +287,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
                 make_static_tile_distribution(a_block_dstr_encode));
         }();

-        const index_t stride_a_bytes = stride_a * sizeof(bf16_t);
-        const index_t stride_b_bytes = stride_b * sizeof(bf16_t);
+        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(bf16_t);
+        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(bf16_t);

         const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
         constexpr auto smem_buf_size =
@@ -343,9 +338,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             "buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds\n"
             "s_add_u32 m0, %[smem_sz], %[s_m0_init]\n"
             "s_cmp_gt_i32 %[s_loop_cnt] 1              ; move a with cond\n"
-            "s_cselect_b32 s86, %[s_stride_a], 0\n"
-            "s_add_u32 s16, s86, s16\n"
-            "s_addc_u32 s17, 0, s17\n"
+            "s_cselect_b32 s86, %[s_tile_os_a], 0      ; move a with cond\n"
+            "s_add_u32 s16, s86, s16                   ; move a with cond\n"
+            "s_addc_u32 s17, 0, s17                    ; move a with cond\n"
             "; -- prefetch A1\n"
             "buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds\n"
             "s_add_u32 m0, %[s_size_per_issue], m0\n"
@@ -364,9 +359,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             "buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds\n"
            "s_add_u32 m0, 0, %[s_m0_init]\n"
            "s_cmp_gt_i32 %[s_loop_cnt] 2              ; move a with cond\n"
-            "s_cselect_b32 s86, %[s_stride_a], 0\n"
-            "s_add_u32 s16, s86, s16\n"
-            "s_addc_u32 s17, 0, s17\n"
+            "s_cselect_b32 s86, %[s_tile_os_a], 0      ; move a with cond\n"
+            "s_add_u32 s16, s86, s16                   ; move a with cond\n"
+            "s_addc_u32 s17, 0, s17                    ; move a with cond\n"
             "; -- prefetch B0\n"
             "buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen\n"
             "buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024\n"
@@ -401,19 +396,19 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             "buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048\n"
             "buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
             "s_cmp_gt_i32 %[s_loop_cnt] 1              ; move b with cond\n"
-            "s_cselect_b32 s86, %[s_stride_b], 0\n"
-            "s_add_u32 s20, s86, s20\n"
-            "s_addc_u32 s21, 0, s21\n"
+            "s_cselect_b32 s86, %[s_tile_os_b], 0      ; move b with cond\n"
+            "s_add_u32 s20, s86, s20                   ; move b with cond\n"
+            "s_addc_u32 s21, 0, s21                    ; move b with cond\n"
             "s_waitcnt vmcnt(40)\n"
             "s_barrier\n"
             "ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
             "ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
             "ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
             "ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
             "ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
             "ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
             "ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
             "ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
             "L_start%=:\n"
             " s_waitcnt vmcnt(24) & lgkmcnt(0)\n"
             " s_barrier\n"
@@ -601,18 +596,18 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             " v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15]\n"
             " v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15]\n"
             " v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15]\n"
             " buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
             " v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15]\n"
             " v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15]\n"
             " s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 0\n"
             " s_cbranch_scc0 L_end%=\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 2             ; move a with cond\n"
-            " s_cselect_b32 s86, %[s_stride_a], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_a], 0\n"
             " s_add_u32 s16, s86, s16\n"
             " s_addc_u32 s17, 0, s17\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 1             ; move b with cond\n"
-            " s_cselect_b32 s86, %[s_stride_b], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_b], 0\n"
             " s_add_u32 s20, s86, s20\n"
             " s_addc_u32 s21, 0, s21\n"
             " ;------------------------------------------\n"
@@ -809,11 +804,11 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             " s_cmp_gt_i32 %[s_loop_cnt] 0\n"
             " s_cbranch_scc0 L_end%=\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 2             ; move a with cond\n"
-            " s_cselect_b32 s86, %[s_stride_a], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_a], 0\n"
             " s_add_u32 s16, s86, s16\n"
             " s_addc_u32 s17, 0, s17\n"
             " s_cmp_gt_i32 %[s_loop_cnt] 1             ; move b with cond\n"
-            " s_cselect_b32 s86, %[s_stride_b], 0\n"
+            " s_cselect_b32 s86, %[s_tile_os_b], 0\n"
             " s_add_u32 s20, s86, s20\n"
             " s_addc_u32 s21, 0, s21\n"
             " s_branch L_start%=\n"
@@ -875,8 +870,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
             [sld_os_5] "n"(sld_os[number<5>{}].value),
             [sld_os_6] "n"(sld_os[number<6>{}].value),
             [sld_os_7] "n"(sld_os[number<7>{}].value),
-            [s_stride_a] "s"(stride_a_bytes),
-            [s_stride_b] "s"(stride_b_bytes)
+            [s_tile_os_a] "s"(tile_offset_a_bytes),
+            [s_tile_os_b] "s"(tile_offset_b_bytes)
             : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
              "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
              "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp

@@ -153,9 +153,25 @@ struct FusedMoeGemmKernel
     CK_TILE_HOST static std::string GetName()
     {
         // sync with generate.py
 #define _SS_ std::string
 #define _TS_ std::to_string
         // clang-format off
-        return "";
+        using S_ = BlockShape;
+        auto prec_str = [&]() {
+            std::string base_str = _SS_(t2s<ADataType>::name);
+            if(!std::is_same_v<ADataType, GDataType>)
+            {
+                base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
+            }
+            return base_str;
+        }();
+
+        return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
+            _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
+            _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
+            _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" +
+            _SS_(Pipeline::name);
 #undef _SS_
 #undef _TS_
         // clang-format on
     }
@@ -199,16 +215,13 @@ struct FusedMoeGemmKernel
         constexpr index_t block_m = BlockShape::Block_M0;
         int max_num_tokens_padded =
             hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
+        // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
         return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }

-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
-    {
-        // return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
-        return Pipeline::GetSmemSize();
-    }
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
@@ -222,6 +235,11 @@ struct FusedMoeGemmKernel
         const auto [sorted_tile_id, intermediate_tile_id] =
             Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
+        // if(threadIdx.x == 0)
+        //     printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
+        //     intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
+        //     static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
+        //     num_sorted_tiles? 1 : 0, intermediate_tile_id);
         if(sorted_tile_id >= num_sorted_tiles)
             return;
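The grid-size bound above reflects worst-case padding from moe sorting: every expert's token list may be rounded up toward a block_m boundary. A worked instance (the numbers are illustrative, not from the commit):

```cpp
// topk * num_tokens real entries, plus up to ~block_m padding slots per expert
inline int max_num_tokens_padded_demo()
{
    const int num_tokens = 128, num_experts = 8, topk = 2, block_m = 32;
    return topk * num_tokens + num_experts * block_m - topk; // 256 + 256 - 2 = 510
}
```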
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp

@@ -66,17 +66,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
         }
     }();

-    static constexpr const char* name = "fused_moe_flatmm_uk";
-
-    // TODO: there are multiple buffers
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
-    {
-        return Policy::template GetSmemSize_A<Problem>();
-    }
+    static constexpr const char* name = "flatmm_uk";

     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
-        return Policy::template GetSmemSize<Problem>();
+        constexpr index_t smem_0 = Policy::template GetUK_1<Problem>().GetSmemSize();
+        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
+        constexpr index_t smem_bridge =
+            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
+        return max(smem_0, max(smem_1, smem_bridge));
     }

     // this is the thread-offset along row/col
@@ -154,7 +152,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
     {
         constexpr index_t n_size = coords.size();
-        array<index_t, n_size> w;
+        array<TopkWeightDataType, n_size> w;
         static_for<0, n_size, 1>{}([&](auto i) {
             w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
         });
@@ -207,34 +205,49 @@ struct FusedMoeGemmPipeline_FlatmmUk
                                index_t sorted_tile_id,
                                index_t intermediate_tile_id)
     {
-        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
-        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
-        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
-        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
+        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
+        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; // w1 (Down, N size)
+        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
+        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
+        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
+        index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
+        index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;

         const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
             reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
-        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
-        index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
-        index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
+        index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
+        index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;

         // nr*kr*w
-        index_t interm_idx_nr =
-            __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
+        index_t interm_idx_nr = __builtin_amdgcn_readfirstlane(
+            intermediate_tile_id * BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
         // printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d,
         //     interm_idx_nr:%d\n", static_cast<int>(blockIdx.x),
         //     static_cast<int>(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id,
         //     interm_idx_nr);
         auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
         auto row_ids_a = GetRowID_A(
             row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
-        auto a_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
-                                       number<row_ids_a.size()>{});
+        auto a_coords = generate_tuple(
+            [&](auto i) {
+                return row_ids_a[i] * kargs.stride_token +
+                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
+            },
+            number<row_ids_a.size()>{});
         auto a_res = make_wave_buffer_resource(
             reinterpret_cast<const ADataType*>(kargs.a_ptr),
             kargs.num_tokens * kargs.stride_token * sizeof(ADataType));

-        const auto g_win = [&]() {
+        auto g_win = [&]() {
             const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                      static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                      interm_idx_nr * kr_0 * BlockShape::Block_W0;
-            const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                 g_ptr,
                 make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                 make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
@@ -243,8 +256,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
             // number<BlockShape::Block_Nr0>{}.fff();
             // number<kAlignmentG>{}.zzz();
-            const auto g_window_ = make_tile_window_linear(
+            auto g_window_ = make_tile_window_linear_raw(
                 g_view_,
                 make_tuple(number<BlockShape::Block_Nr0>{},
                            number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
@@ -271,8 +284,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
                 number<kAlignmentD>{},
                 number<1>{});

-            const auto d_window_ = make_tile_window_linear(
+            const auto d_window_ = make_tile_window_linear_raw(
                 d_view_,
                 make_tuple(number<BlockShape::Block_Nr1>{},
                            number<BlockShape::Block_Kr1>{},
                            number<BlockShape::Block_W1>{}),
@@ -309,14 +322,23 @@ struct FusedMoeGemmPipeline_FlatmmUk
             constexpr auto i_nr_  = number<i % Nr_>{};
             constexpr auto i_kr0_ = number<i / Nr_>{};
-            return i_nr_ * kargs.intermediate_size * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
+            return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
                    base_os_;
         },
         number<num_offsets_>{});
     }();
 #endif

-        auto o_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
-                                       number<a_coords.size()>{});
+        auto o_coords = generate_tuple(
+            [&](auto i) {
+                return row_ids_a[i] * kargs.stride_token +
+                       threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
+            },
+            number<row_ids_a.size()>{});
+        auto o_flags =
+            generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
+                           number<row_ids_a.size()>{});

         auto bridge_sst_win = [&]() {
             return make_tile_window(
                 make_tensor_view<address_space_enum::lds>(
@@ -332,7 +354,79 @@ struct FusedMoeGemmPipeline_FlatmmUk
         auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
         auto w_scale = GetWeightScale(
             row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
+#if 0
+        printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
+               "interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, "
+               "o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n",
+               static_cast<int>(blockIdx.x),
+               static_cast<int>(blockIdx.y),
+               static_cast<int>(threadIdx.x),
+               sorted_tile_id,
+               intermediate_tile_id,
+               expert_id,
+               interm_idx_nr,
+               row_coords_a[0],
+               row_coords_a[1],
+               row_coords_a[7],
+               row_ids_a[0],
+               row_ids_a[1],
+               row_ids_a[7],
+               kr_0 * BlockShape::Block_W0,
+               g_coords[number<0>{}],
+               g_coords[number<1>{}],
+               g_coords[number<7>{}],
+               o_coords[number<0>{}],
+               o_coords[number<1>{}],
+               o_coords[number<2>{}],
+               o_coords[number<3>{}],
+               o_coords[number<4>{}],
+               o_coords[number<5>{}],
+               o_coords[number<6>{}],
+               o_coords[number<7>{}],
+               // (row_ids_a[0] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[1] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[2] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[3] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[4] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[5] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[6] >= kargs.num_tokens ? 1 : 0),
+               // (row_ids_a[7] >= kargs.num_tokens ? 1 : 0)
+               (row_ids_a[0] < kargs.num_tokens && static_cast<index_t>(o_coords[number<0>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[1] < kargs.num_tokens && static_cast<index_t>(o_coords[number<1>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[2] < kargs.num_tokens && static_cast<index_t>(o_coords[number<2>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[3] < kargs.num_tokens && static_cast<index_t>(o_coords[number<3>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[4] < kargs.num_tokens && static_cast<index_t>(o_coords[number<4>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[5] < kargs.num_tokens && static_cast<index_t>(o_coords[number<5>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[6] < kargs.num_tokens && static_cast<index_t>(o_coords[number<6>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0),
+               (row_ids_a[7] < kargs.num_tokens && static_cast<index_t>(o_coords[number<7>{}]) >=
+                    (kargs.num_tokens * kargs.stride_token) ? 7777 : 0));
+#endif

         auto uk_0 = Policy::template GetUK_0<Problem>();
         auto acc_0 = uk_0(a_res,
                           a_coords,
@@ -340,25 +434,29 @@ struct FusedMoeGemmPipeline_FlatmmUk
                           g_coords,
                           smem,
                           kargs.hidden_size,
-                          kargs.stride_token,
-                          BlockShape::Block_Kr0 * BlockShape::Block_W0);
+                          BlockShape::Block_K0,                          // tile offset for B matrix each unroll
+                          BlockShape::Block_Kr0 * BlockShape::Block_W0); // tile offset for B matrix each unroll
         // return ;
         sweep_tile(acc_0, [&](auto idx) {
             typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]);
         });

         auto y_pre = cast_tile<YDataType>(acc_0);
         store_tile(bridge_sst_win, y_pre);
         block_sync_lds();

         auto uk_1 = Policy::template GetUK_1<Problem>();
         uk_1(d_res,
              d_coords,
              o_res,
              o_coords,
+             o_flags,
              smem,
              kargs.hidden_size,
-             kargs.hidden_size,
+             kargs.hidden_size, // total n number
              w_scale,
-             BlockShape::Block_Kr0 * BlockShape::Block_W0,
-             kargs.stride_token);
+             BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
+             BlockShape::Block_N1);                               // along N
     }
 };