Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
339a674b
Commit
339a674b
authored
Jan 11, 2025
by
shengnxu
Browse files
current status: single WG, memory out of bound
parent
5d00b37e
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
920 additions
and
683 deletions
+920
-683
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+4
-4
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+3
-3
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+3
-3
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+5
-5
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+7
-2
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+145
-22
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
+35
-33
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+3
-557
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
+590
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+33
-33
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+13
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+79
-19
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
339a674b
...
@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -19,14 +19,14 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
//
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
//
r = fused_moegemm_<t_>(s, a);
}
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
//
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
//
r = fused_moegemm_<t_>(s, a);
}
}
else
if
(
t
.
prec_i
==
"int8"
&&
t
.
prec_w
==
"int8"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"int8"
&&
t
.
prec_w
==
"int8"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
339a674b
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
//
template float fused_moegemm_<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
//
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
//
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
View file @
339a674b
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
#include "fused_moegemm_api_internal.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
//
template float fused_moegemm_<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
//
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
//
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/main.cpp
View file @
339a674b
...
@@ -87,11 +87,11 @@ void topid_unique_gen(
...
@@ -87,11 +87,11 @@ void topid_unique_gen(
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"t"
,
"
128
"
,
"num input tokens"
)
arg_parser
.
insert
(
"t"
,
"
32
"
,
"num input tokens"
)
.
insert
(
"e"
,
"
32
"
,
"num of experts"
)
.
insert
(
"e"
,
"
1
"
,
"num of experts"
)
.
insert
(
"k"
,
"
5
"
,
"topk"
)
.
insert
(
"k"
,
"
1
"
,
"topk"
)
.
insert
(
"h"
,
"
8192
"
,
"hidden_size of this model"
)
.
insert
(
"h"
,
"
256
"
,
"hidden_size of this model"
)
.
insert
(
"i"
,
"
8192
"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"i"
,
"
4096
"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"bm"
,
"32"
,
"blocking factor for sorted tokens"
)
.
insert
(
"bm"
,
"32"
,
"blocking factor for sorted tokens"
)
.
insert
(
"tp"
,
"8"
,
"tensor parallel size"
)
.
insert
(
"tp"
,
"8"
,
"tensor parallel size"
)
...
...
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
View file @
339a674b
...
@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -242,6 +242,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{
{
using
ADataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
AScaleDataType
=
float
;
// TODO: need paired with tile_window_linear!
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// TODO: need call init_raw() before call this function!
...
@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -258,7 +259,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
k
,
index_t
k
,
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_a
,
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
)
// for each tile, the offset to move for each unroll
index_t
tile_offset_b
,
index_t
a_bound_
)
// for each tile, the offset to move for each unroll
{
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
4
/*2x per dword*/
);
// 8
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
4
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
...
@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -449,7 +451,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[
c62
]
"+v"
(
v_z62
),
[
c62
]
"+v"
(
v_z62
),
[
c63
]
"+v"
(
v_z63
),
[
c63
]
"+v"
(
v_z63
),
[
s_mem_
]
"+r"
(
smem
)
[
s_mem_
]
"+r"
(
smem
)
:
[
a_scale0
]
"v"
(
a_scale_
[
0
]),
:
[
a_bound
]
"s"
(
static_cast
<
int
>
(
a_bound_
*
sizeof
(
ADataType
))),
// [a_scale_bound]"s"(a_scale_bound_ * sizeof(AScaleDataType)),
[
a_scale0
]
"v"
(
a_scale_
[
0
]),
[
a_scale1
]
"v"
(
a_scale_
[
1
]),
[
a_scale1
]
"v"
(
a_scale_
[
1
]),
[
gq_scale0
]
"v"
(
gq_scale_
[
0
]),
[
gq_scale0
]
"v"
(
gq_scale_
[
0
]),
[
gq_scale1
]
"v"
(
gq_scale_
[
1
]),
[
gq_scale1
]
"v"
(
gq_scale_
[
1
]),
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
339a674b
...
@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -81,6 +81,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
template
<
template
<
// typename DQRes,
// typename DQRes,
// typename BRes,
// typename BRes,
typename
Tokenids
,
typename
DQCoords
,
typename
DQCoords
,
typename
BCoords
,
typename
BCoords
,
typename
ORes
,
typename
ORes
,
...
@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -92,6 +93,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
operator
()(
operator
()(
// const DQRes& res_dq,
// const DQRes& res_dq,
// const BRes& res_b,
// const BRes& res_b,
const
Tokenids
&
token_id_
,
const
DQCoords
&
cached_coords_dq
,
const
DQCoords
&
cached_coords_dq
,
const
BCoords
&
cached_coords_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
ORes
&
res_o
,
...
@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -108,7 +110,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
{
{
static_assert
(
BCoords
::
size
()
==
4
);
// 8
static_assert
(
BCoords
::
size
()
==
4
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_offset_half_b_bytes
=
tile_offset_half_b
*
sizeof
(
BDataType
);
const
index_t
tile_offset_half_b_bytes
=
tile_offset_half_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
...
@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -155,8 +156,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc"
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc"
#undef CK_TILE_FLATMM_UK_MFMA
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
)
,
:
[
smem_
]
"+r"
(
smem
)
[
s_loop_cnt
]
"+s"
(
loop_cnt
)
//
[s_loop_cnt]"+s"(loop_cnt)
:
[
sld_a_base
]
"n"
(
0
),
:
[
sld_a_base
]
"n"
(
0
),
// [shfl_base]"n"(0),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sld_y_os]"v"(sld_y_os),
...
@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -164,8 +165,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sst]"v"(sfl_sst),
// [v_sfl_sst]"v"(sfl_sst),
[
smq_scale0
]
"s"
(
smq_scale_
[
0
]),
[
smq_scale0
]
"s"
(
smq_scale_
[
0
]),
[
smq_scale1
]
"s"
(
smq_scale_
[
1
]),
[
smq_scale1
]
"s"
(
smq_scale_
[
1
]),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
//
[s_res_o0]"s"(res_o[0]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
//[s_res_o3]"s"(res_o[3]),
[
v_os_dq
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_dq
*
sizeof
(
DScaleDataType
))),
[
v_os_dq
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_dq
*
sizeof
(
DScaleDataType
))),
...
@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -184,19 +185,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
)
,
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
)
// [scale_0]"v"(s0),
// [scale_0]"v"(s0),
// [scale_1]"v"(s1),
// [scale_1]"v"(s1),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [v_nan_hi]"v"(nan_hi),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
//
[s_execflag_0]"s"(o_flags[number<0>{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
//
[s_execflag_1]"s"(o_flags[number<1>{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
//
[s_execflag_2]"s"(o_flags[number<2>{}]),
[
s_execflag_3
]
"s"
(
o_flags
[
number
<
3
>
{}]),
//
[s_execflag_3]"s"(o_flags[number<3>{}]),
[
s_execflag_4
]
"s"
(
o_flags
[
number
<
4
>
{}]),
//
[s_execflag_4]"s"(o_flags[number<4>{}]),
[
s_execflag_5
]
"s"
(
o_flags
[
number
<
5
>
{}]),
//
[s_execflag_5]"s"(o_flags[number<5>{}]),
[
s_execflag_6
]
"s"
(
o_flags
[
number
<
6
>
{}]),
//
[s_execflag_6]"s"(o_flags[number<6>{}]),
[
s_execflag_7
]
"s"
(
o_flags
[
number
<
7
>
{}])
//
[s_execflag_7]"s"(o_flags[number<7>{}])
:
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
...
@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -228,7 +229,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s6"
,
"s7"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s34"
,
"s35"
,
"s38"
,
"s39"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
...
@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -260,12 +263,14 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
"v253"
,
"v254"
,
"v255"
);
);
if
(
hipBlockIdx_x
==
0
&&
hipBlockIdx_y
==
0
&&
hipBlockIdx_z
==
0
&&
hipThreadIdx_x
==
5
)
{
printf
(
"
\n
sn0 done
\n
"
);
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
// {
// // printf("\n xyz%x,%x,%x,thread idx:%xsn1 done\n",blockIdx.x, blockIdx.y, blockIdx.z ,threadIdx.x );
// printf("\n sn1 done\n");
// }
// return;
asm
volatile
(
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc"
...
@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
...
@@ -288,6 +293,123 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
)
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [s_execflag_0]"s"(o_flags[number<0>{}]),
// [s_execflag_1]"s"(o_flags[number<1>{}]),
// [s_execflag_2]"s"(o_flags[number<2>{}]),
// [s_execflag_3]"s"(o_flags[number<3>{}]),
// [s_execflag_4]"s"(o_flags[number<4>{}]),
// [s_execflag_5]"s"(o_flags[number<5>{}]),
// [s_execflag_6]"s"(o_flags[number<6>{}]),
// [s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory"
,
"a0"
,
"a1"
,
"a2"
,
"a3"
,
"a4"
,
"a5"
,
"a6"
,
"a7"
,
"a8"
,
"a9"
,
"a10"
,
"a11"
,
"a12"
,
"a13"
,
"a14"
,
"a15"
,
"a16"
,
"a17"
,
"a18"
,
"a19"
,
"a20"
,
"a21"
,
"a22"
,
"a23"
,
"a24"
,
"a25"
,
"a26"
,
"a27"
,
"a28"
,
"a29"
,
"a30"
,
"a31"
,
"a32"
,
"a33"
,
"a34"
,
"a35"
,
"a36"
,
"a37"
,
"a38"
,
"a39"
,
"a40"
,
"a41"
,
"a42"
,
"a43"
,
"a44"
,
"a45"
,
"a46"
,
"a47"
,
"a48"
,
"a49"
,
"a50"
,
"a51"
,
"a52"
,
"a53"
,
"a54"
,
"a55"
,
"a56"
,
"a57"
,
"a58"
,
"a59"
,
"a60"
,
"a61"
,
"a62"
,
"a63"
,
"a64"
,
"a65"
,
"a66"
,
"a67"
,
"a68"
,
"a69"
,
"a70"
,
"a71"
,
"a72"
,
"a73"
,
"a74"
,
"a75"
,
"a76"
,
"a77"
,
"a78"
,
"a79"
,
"a80"
,
"a81"
,
"a82"
,
"a83"
,
"a84"
,
"a85"
,
"a86"
,
"a87"
,
"a88"
,
"a89"
,
"a90"
,
"a91"
,
"a92"
,
"a93"
,
"a94"
,
"a95"
,
"a96"
,
"a97"
,
"a98"
,
"a99"
,
"a100"
,
"a101"
,
"a102"
,
"a103"
,
"a104"
,
"a105"
,
"a106"
,
"a107"
,
"a108"
,
"a109"
,
"a110"
,
"a111"
,
"a112"
,
"a113"
,
"a114"
,
"a115"
,
"a116"
,
"a117"
,
"a118"
,
"a119"
,
"a120"
,
"a121"
,
"a122"
,
"a123"
,
"a124"
,
"a125"
,
"a126"
,
"a127"
,
"a128"
,
"a129"
,
"a130"
,
"a131"
,
"a132"
,
"a133"
,
"a134"
,
"a135"
,
"a136"
,
"a137"
,
"a138"
,
"a139"
,
"a140"
,
"a141"
,
"a142"
,
"a143"
,
"a144"
,
"a145"
,
"a146"
,
"a147"
,
"a148"
,
"a149"
,
"a150"
,
"a151"
,
"a152"
,
"a153"
,
"a154"
,
"a155"
,
"a156"
,
"a157"
,
"a158"
,
"a159"
,
"a160"
,
"a161"
,
"a162"
,
"a163"
,
"a164"
,
"a165"
,
"a166"
,
"a167"
,
"a168"
,
"a169"
,
"a170"
,
"a171"
,
"a172"
,
"a173"
,
"a174"
,
"a175"
,
"a176"
,
"a177"
,
"a178"
,
"a179"
,
"a180"
,
"a181"
,
"a182"
,
"a183"
,
"a184"
,
"a185"
,
"a186"
,
"a187"
,
"a188"
,
"a189"
,
"a190"
,
"a191"
,
"a192"
,
"a193"
,
"a194"
,
"a195"
,
"a196"
,
"a197"
,
"a198"
,
"a199"
,
"a200"
,
"a201"
,
"a202"
,
"a203"
,
"a204"
,
"a205"
,
"a206"
,
"a207"
,
"a208"
,
"a209"
,
"a210"
,
"a211"
,
"a212"
,
"a213"
,
"a214"
,
"a215"
,
"a216"
,
"a217"
,
"a218"
,
"a219"
,
"a220"
,
"a221"
,
"a222"
,
"a223"
,
"a224"
,
"a225"
,
"a226"
,
"a227"
,
"a228"
,
"a229"
,
"a230"
,
"a231"
,
"a232"
,
"a233"
,
"a234"
,
"a235"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s38"
,
"s39"
,
"s34"
,
"s35"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v12"
,
"v13"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
"v92"
,
"v93"
,
"v94"
,
"v95"
,
"v96"
,
"v97"
,
"v98"
,
"v99"
,
"v100"
,
"v101"
,
"v102"
,
"v103"
,
"v104"
,
"v105"
,
"v106"
,
"v107"
,
"v108"
,
"v109"
,
"v110"
,
"v111"
,
"v112"
,
"v113"
,
"v114"
,
"v115"
,
"v116"
,
"v117"
,
"v118"
,
"v119"
,
"v120"
,
"v121"
,
"v122"
,
"v123"
,
"v124"
,
"v125"
,
"v126"
,
"v127"
,
"v128"
,
"v129"
,
"v130"
,
"v131"
,
"v132"
,
"v133"
,
"v134"
,
"v135"
,
"v136"
,
"v137"
,
"v138"
,
"v139"
,
"v140"
,
"v141"
,
"v142"
,
"v143"
,
"v144"
,
"v145"
,
"v146"
,
"v147"
,
"v148"
,
"v149"
,
"v150"
,
"v151"
,
"v152"
,
"v153"
,
"v154"
,
"v155"
,
"v156"
,
"v157"
,
"v158"
,
"v159"
,
"v160"
,
"v161"
,
"v162"
,
"v163"
,
"v164"
,
"v165"
,
"v166"
,
"v167"
,
"v168"
,
"v169"
,
"v170"
,
"v171"
,
"v172"
,
"v173"
,
"v174"
,
"v175"
,
"v176"
,
"v177"
,
"v178"
,
"v179"
,
"v180"
,
"v181"
,
"v182"
,
"v183"
,
"v184"
,
"v185"
,
"v186"
,
"v187"
,
"v188"
,
"v189"
,
"v190"
,
"v191"
,
"v192"
,
"v193"
,
"v194"
,
"v195"
,
"v196"
,
"v197"
,
"v198"
,
"v199"
,
"v200"
,
"v201"
,
"v202"
,
"v203"
,
"v204"
,
"v205"
,
"v206"
,
"v207"
,
"v208"
,
"v209"
,
"v210"
,
"v211"
,
"v212"
,
"v213"
,
"v214"
,
"v215"
,
"v216"
,
"v217"
,
"v218"
,
"v219"
,
"v220"
,
"v221"
,
"v222"
,
"v223"
,
"v224"
,
"v225"
,
"v226"
,
"v227"
,
"v228"
,
"v229"
,
"v230"
,
"v231"
,
"v232"
,
"v233"
,
"v234"
,
"v235"
,
"v236"
,
"v237"
,
"v238"
,
"v239"
,
"v240"
,
"v241"
,
"v242"
,
"v243"
,
"v244"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
);
// if(hipBlockIdx_x == 0 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 &&
// hipThreadIdx_x == 0)
// {
// printf("\n sn2 done\n");
// }
return
;
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_INT8
#include "uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
loop_cnt
)
:
[
sld_a_base
]
"n"
(
0
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
[
v_os_dq
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_dq
*
sizeof
(
DScaleDataType
))),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
3
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
4
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
5
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
6
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
7
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_b0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
0
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
s_token_id0
]
"s"
(
token_id_
[
number
<
0
>
{}]),
[
s_token_id1
]
"s"
(
token_id_
[
number
<
1
>
{}]),
[
s_token_id2
]
"s"
(
token_id_
[
number
<
2
>
{}]),
[
s_token_id3
]
"s"
(
token_id_
[
number
<
3
>
{}]),
[
s_token_id4
]
"s"
(
token_id_
[
number
<
4
>
{}]),
[
s_token_id5
]
"s"
(
token_id_
[
number
<
5
>
{}]),
[
s_token_id6
]
"s"
(
token_id_
[
number
<
6
>
{}]),
[
s_token_id7
]
"s"
(
token_id_
[
number
<
7
>
{}]),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
...
@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
...
@@ -335,7 +457,8 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s6"
,
"s7"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s38"
,
"s39"
,
"s34"
,
"s35"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
...
@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
...
@@ -367,7 +490,7 @@ if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 &&
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v245"
,
"v246"
,
"v247"
,
"v248"
,
"v249"
,
"v250"
,
"v251"
,
"v252"
,
"v253"
,
"v254"
,
"v255"
"v253"
,
"v254"
,
"v255"
);
);
#pragma clang diagnostic pop
#pragma clang diagnostic pop
// clang-format on
// clang-format on
}
}
};
};
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
View file @
339a674b
...
@@ -31,8 +31,8 @@
...
@@ -31,8 +31,8 @@
" v_lshrrev_b32 v3, 6, v0
\n
"
" v_lshrrev_b32 v3, 6, v0
\n
"
" v_readfirstlane_b32 s7, v3
\n
"
" v_readfirstlane_b32 s7, v3
\n
"
" s_waitcnt vmcnt(24)
\n
"
" s_waitcnt vmcnt(24)
\n
"
"
buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen
\n
"
"
buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v56, v130, v130
\n
"
" v_mul_f32 v56, v130, v130
\n
"
...
@@ -65,7 +65,7 @@
...
@@ -65,7 +65,7 @@
" v_mul_f32 v129, v129, v55
\n
"
" v_mul_f32 v129, v129, v55
\n
"
" v_mul_f32 v130, v130, v56
\n
"
" v_mul_f32 v130, v130, v56
\n
"
" v_mul_f32 v131, v131, v57
\n
"
" v_mul_f32 v131, v131, v57
\n
"
"
buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v132, v132
\n
"
" v_mul_f32 v54, v132, v132
\n
"
" v_mul_f32 v55, v133, v133
\n
"
" v_mul_f32 v55, v133, v133
\n
"
" v_mul_f32 v56, v134, v134
\n
"
" v_mul_f32 v56, v134, v134
\n
"
...
@@ -86,7 +86,7 @@
...
@@ -86,7 +86,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -99,7 +99,7 @@
...
@@ -99,7 +99,7 @@
" v_mul_f32 v133, v133, v55
\n
"
" v_mul_f32 v133, v133, v55
\n
"
" v_mul_f32 v134, v134, v56
\n
"
" v_mul_f32 v134, v134, v56
\n
"
" v_mul_f32 v135, v135, v57
\n
"
" v_mul_f32 v135, v135, v57
\n
"
"
buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v136, v136
\n
"
" v_mul_f32 v54, v136, v136
\n
"
" v_mul_f32 v55, v137, v137
\n
"
" v_mul_f32 v55, v137, v137
\n
"
" v_mul_f32 v56, v138, v138
\n
"
" v_mul_f32 v56, v138, v138
\n
"
...
@@ -120,7 +120,7 @@
...
@@ -120,7 +120,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -133,7 +133,7 @@
...
@@ -133,7 +133,7 @@
" v_mul_f32 v137, v137, v55
\n
"
" v_mul_f32 v137, v137, v55
\n
"
" v_mul_f32 v138, v138, v56
\n
"
" v_mul_f32 v138, v138, v56
\n
"
" v_mul_f32 v139, v139, v57
\n
"
" v_mul_f32 v139, v139, v57
\n
"
"
buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v140, v140
\n
"
" v_mul_f32 v54, v140, v140
\n
"
" v_mul_f32 v55, v141, v141
\n
"
" v_mul_f32 v55, v141, v141
\n
"
" v_mul_f32 v56, v142, v142
\n
"
" v_mul_f32 v56, v142, v142
\n
"
...
@@ -154,7 +154,7 @@
...
@@ -154,7 +154,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -168,7 +168,7 @@
...
@@ -168,7 +168,7 @@
" v_mul_f32 v142, v142, v56
\n
"
" v_mul_f32 v142, v142, v56
\n
"
" v_mul_f32 v143, v143, v57
\n
"
" v_mul_f32 v143, v143, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" s_waitcnt vmcnt(24)
\n
"
"
buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v144, v144
\n
"
" v_mul_f32 v54, v144, v144
\n
"
" v_mul_f32 v55, v145, v145
\n
"
" v_mul_f32 v55, v145, v145
\n
"
" v_mul_f32 v56, v146, v146
\n
"
" v_mul_f32 v56, v146, v146
\n
"
...
@@ -189,7 +189,7 @@
...
@@ -189,7 +189,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -202,7 +202,7 @@
...
@@ -202,7 +202,7 @@
" v_mul_f32 v145, v145, v55
\n
"
" v_mul_f32 v145, v145, v55
\n
"
" v_mul_f32 v146, v146, v56
\n
"
" v_mul_f32 v146, v146, v56
\n
"
" v_mul_f32 v147, v147, v57
\n
"
" v_mul_f32 v147, v147, v57
\n
"
"
buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v148, v148
\n
"
" v_mul_f32 v54, v148, v148
\n
"
" v_mul_f32 v55, v149, v149
\n
"
" v_mul_f32 v55, v149, v149
\n
"
" v_mul_f32 v56, v150, v150
\n
"
" v_mul_f32 v56, v150, v150
\n
"
...
@@ -223,7 +223,7 @@
...
@@ -223,7 +223,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -236,7 +236,7 @@
...
@@ -236,7 +236,7 @@
" v_mul_f32 v149, v149, v55
\n
"
" v_mul_f32 v149, v149, v55
\n
"
" v_mul_f32 v150, v150, v56
\n
"
" v_mul_f32 v150, v150, v56
\n
"
" v_mul_f32 v151, v151, v57
\n
"
" v_mul_f32 v151, v151, v57
\n
"
"
buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v152, v152
\n
"
" v_mul_f32 v54, v152, v152
\n
"
" v_mul_f32 v55, v153, v153
\n
"
" v_mul_f32 v55, v153, v153
\n
"
" v_mul_f32 v56, v154, v154
\n
"
" v_mul_f32 v56, v154, v154
\n
"
...
@@ -257,7 +257,7 @@
...
@@ -257,7 +257,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -270,7 +270,7 @@
...
@@ -270,7 +270,7 @@
" v_mul_f32 v153, v153, v55
\n
"
" v_mul_f32 v153, v153, v55
\n
"
" v_mul_f32 v154, v154, v56
\n
"
" v_mul_f32 v154, v154, v56
\n
"
" v_mul_f32 v155, v155, v57
\n
"
" v_mul_f32 v155, v155, v57
\n
"
"
buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v156, v156
\n
"
" v_mul_f32 v54, v156, v156
\n
"
" v_mul_f32 v55, v157, v157
\n
"
" v_mul_f32 v55, v157, v157
\n
"
" v_mul_f32 v56, v158, v158
\n
"
" v_mul_f32 v56, v158, v158
\n
"
...
@@ -291,7 +291,7 @@
...
@@ -291,7 +291,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
" s_add_u32 s12, %[s_tile_os_b_half], s12
\n
"
" s_add_u32 s12, %[s_tile_os_b_half], s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
...
@@ -307,7 +307,7 @@
...
@@ -307,7 +307,7 @@
" v_mul_f32 v158, v158, v56
\n
"
" v_mul_f32 v158, v158, v56
\n
"
" v_mul_f32 v159, v159, v57
\n
"
" v_mul_f32 v159, v159, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" s_waitcnt vmcnt(24)
\n
"
"
buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[64:67], %[v_os_b0], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v160, v160
\n
"
" v_mul_f32 v54, v160, v160
\n
"
" v_mul_f32 v55, v161, v161
\n
"
" v_mul_f32 v55, v161, v161
\n
"
" v_mul_f32 v56, v162, v162
\n
"
" v_mul_f32 v56, v162, v162
\n
"
...
@@ -328,7 +328,7 @@
...
@@ -328,7 +328,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[68:71], %[v_os_b0], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -341,7 +341,7 @@
...
@@ -341,7 +341,7 @@
" v_mul_f32 v161, v161, v55
\n
"
" v_mul_f32 v161, v161, v55
\n
"
" v_mul_f32 v162, v162, v56
\n
"
" v_mul_f32 v162, v162, v56
\n
"
" v_mul_f32 v163, v163, v57
\n
"
" v_mul_f32 v163, v163, v57
\n
"
"
buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[72:75], %[v_os_b0], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v164, v164
\n
"
" v_mul_f32 v54, v164, v164
\n
"
" v_mul_f32 v55, v165, v165
\n
"
" v_mul_f32 v55, v165, v165
\n
"
" v_mul_f32 v56, v166, v166
\n
"
" v_mul_f32 v56, v166, v166
\n
"
...
@@ -362,7 +362,7 @@
...
@@ -362,7 +362,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[76:79], %[v_os_b0], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -375,7 +375,7 @@
...
@@ -375,7 +375,7 @@
" v_mul_f32 v165, v165, v55
\n
"
" v_mul_f32 v165, v165, v55
\n
"
" v_mul_f32 v166, v166, v56
\n
"
" v_mul_f32 v166, v166, v56
\n
"
" v_mul_f32 v167, v167, v57
\n
"
" v_mul_f32 v167, v167, v57
\n
"
"
buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[80:83], %[v_os_b1], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v168, v168
\n
"
" v_mul_f32 v54, v168, v168
\n
"
" v_mul_f32 v55, v169, v169
\n
"
" v_mul_f32 v55, v169, v169
\n
"
" v_mul_f32 v56, v170, v170
\n
"
" v_mul_f32 v56, v170, v170
\n
"
...
@@ -396,7 +396,7 @@
...
@@ -396,7 +396,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[84:87], %[v_os_b1], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -409,7 +409,7 @@
...
@@ -409,7 +409,7 @@
" v_mul_f32 v169, v169, v55
\n
"
" v_mul_f32 v169, v169, v55
\n
"
" v_mul_f32 v170, v170, v56
\n
"
" v_mul_f32 v170, v170, v56
\n
"
" v_mul_f32 v171, v171, v57
\n
"
" v_mul_f32 v171, v171, v57
\n
"
"
buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[88:91], %[v_os_b1], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v172, v172
\n
"
" v_mul_f32 v54, v172, v172
\n
"
" v_mul_f32 v55, v173, v173
\n
"
" v_mul_f32 v55, v173, v173
\n
"
" v_mul_f32 v56, v174, v174
\n
"
" v_mul_f32 v56, v174, v174
\n
"
...
@@ -430,7 +430,7 @@
...
@@ -430,7 +430,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[92:95], %[v_os_b1], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -444,7 +444,7 @@
...
@@ -444,7 +444,7 @@
" v_mul_f32 v174, v174, v56
\n
"
" v_mul_f32 v174, v174, v56
\n
"
" v_mul_f32 v175, v175, v57
\n
"
" v_mul_f32 v175, v175, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" s_waitcnt vmcnt(24)
\n
"
"
buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[96:99], %[v_os_b2], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v176, v176
\n
"
" v_mul_f32 v54, v176, v176
\n
"
" v_mul_f32 v55, v177, v177
\n
"
" v_mul_f32 v55, v177, v177
\n
"
" v_mul_f32 v56, v178, v178
\n
"
" v_mul_f32 v56, v178, v178
\n
"
...
@@ -465,7 +465,7 @@
...
@@ -465,7 +465,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[100:103], %[v_os_b2], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -478,7 +478,7 @@
...
@@ -478,7 +478,7 @@
" v_mul_f32 v177, v177, v55
\n
"
" v_mul_f32 v177, v177, v55
\n
"
" v_mul_f32 v178, v178, v56
\n
"
" v_mul_f32 v178, v178, v56
\n
"
" v_mul_f32 v179, v179, v57
\n
"
" v_mul_f32 v179, v179, v57
\n
"
"
buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[104:107], %[v_os_b2], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v180, v180
\n
"
" v_mul_f32 v54, v180, v180
\n
"
" v_mul_f32 v55, v181, v181
\n
"
" v_mul_f32 v55, v181, v181
\n
"
" v_mul_f32 v56, v182, v182
\n
"
" v_mul_f32 v56, v182, v182
\n
"
...
@@ -499,7 +499,7 @@
...
@@ -499,7 +499,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[108:111], %[v_os_b2], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -512,7 +512,7 @@
...
@@ -512,7 +512,7 @@
" v_mul_f32 v181, v181, v55
\n
"
" v_mul_f32 v181, v181, v55
\n
"
" v_mul_f32 v182, v182, v56
\n
"
" v_mul_f32 v182, v182, v56
\n
"
" v_mul_f32 v183, v183, v57
\n
"
" v_mul_f32 v183, v183, v57
\n
"
"
buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen
\n
"
"buffer_load_dwordx4 acc[112:115], %[v_os_b3], s[12:15], 0 offen
\n
"
" v_mul_f32 v54, v184, v184
\n
"
" v_mul_f32 v54, v184, v184
\n
"
" v_mul_f32 v55, v185, v185
\n
"
" v_mul_f32 v55, v185, v185
\n
"
" v_mul_f32 v56, v186, v186
\n
"
" v_mul_f32 v56, v186, v186
\n
"
...
@@ -533,7 +533,7 @@
...
@@ -533,7 +533,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[116:119], %[v_os_b3], s[12:15], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -546,7 +546,7 @@
...
@@ -546,7 +546,7 @@
" v_mul_f32 v185, v185, v55
\n
"
" v_mul_f32 v185, v185, v55
\n
"
" v_mul_f32 v186, v186, v56
\n
"
" v_mul_f32 v186, v186, v56
\n
"
" v_mul_f32 v187, v187, v57
\n
"
" v_mul_f32 v187, v187, v57
\n
"
"
buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[120:123], %[v_os_b3], s[12:15], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v188, v188
\n
"
" v_mul_f32 v54, v188, v188
\n
"
" v_mul_f32 v55, v189, v189
\n
"
" v_mul_f32 v55, v189, v189
\n
"
" v_mul_f32 v56, v190, v190
\n
"
" v_mul_f32 v56, v190, v190
\n
"
...
@@ -567,7 +567,7 @@
...
@@ -567,7 +567,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
"
buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[124:127], %[v_os_b3], s[12:15], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -644,7 +644,7 @@
...
@@ -644,7 +644,7 @@
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13
\n
"
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13
\n
"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14
\n
"
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14
\n
"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15
\n
"
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15
\n
"
"
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
"
;--
buffer_load_dword v12, %[v_os_dq], s[16:19], 0 offen
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
...
@@ -974,3 +974,5 @@
...
@@ -974,3 +974,5 @@
#undef _UK_PK_CVT_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
339a674b
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_3.inc
0 → 100644
View file @
339a674b
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
339a674b
...
@@ -19,9 +19,9 @@
...
@@ -19,9 +19,9 @@
" v_mul_f32 "
a2
", "
gq
", "
a2
" row_newbcast: "
brd2
"
\n
"
\
" v_mul_f32 "
a2
", "
gq
", "
a2
" row_newbcast: "
brd2
"
\n
"
\
" v_mul_f32 "
a3
", "
gq
", "
a3
" row_newbcast:"
brd3
"
\n
"
" v_mul_f32 "
a3
", "
gq
", "
a3
" row_newbcast:"
brd3
"
\n
"
" s_mov_b32 s22, %[a_bound]
\n
"
"s_mov_b32 s20, %[s_res_a0]
\n
"
"s_mov_b32 s20, %[s_res_a0]
\n
"
"s_mov_b32 s21, %[s_res_a1]
\n
"
"s_mov_b32 s21, %[s_res_a1]
\n
"
"s_mov_b32 s22, %[s_res_a2]
\n
"
"s_mov_b32 s23, %[s_res_a3]
\n
"
"s_mov_b32 s23, %[s_res_a3]
\n
"
"s_mov_b32 s24, %[s_res_b0]
\n
"
"s_mov_b32 s24, %[s_res_b0]
\n
"
"s_mov_b32 s25, %[s_res_b1]
\n
"
"s_mov_b32 s25, %[s_res_b1]
\n
"
...
@@ -110,38 +110,38 @@
...
@@ -110,38 +110,38 @@
" s_add_u32 s20, s57, s20
\n
"
" s_add_u32 s20, s57, s20
\n
"
" s_addc_u32 s21, 0, s21
\n
"
" s_addc_u32 s21, 0, s21
\n
"
"; -- prefetch B0
\n
"
"; -- prefetch B0
\n
"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[24:27], 0 offen offset:3072
\n
"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen
\n
"
"
buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[24:27], 0 offen
\n
"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024
\n
"
"
buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[24:27], 0 offen offset:1024
\n
"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048
\n
"
"
buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[24:27], 0 offen offset:2048
\n
"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072
\n
"
"
buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[24:27], 0 offen offset:3072
\n
"
"s_add_u32 s24, s58, s24
\n
"
"s_add_u32 s24, s58, s24
\n
"
"s_addc_u32 s25, 0, s25
\n
"
"s_addc_u32 s25, 0, s25
\n
"
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
339a674b
...
@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel
...
@@ -237,12 +237,23 @@ struct FusedMoeGemmKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
if
constexpr
(
UseUK
)
if
constexpr
(
UseUK
)
{
{
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
65536
];
// index_t s_size = GetSmemSize();
// ADataType{}.aaa();
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
// __builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x == 0){
// printf("num_sorted_tiles %d\n", num_sorted_tiles);
// printf("data type :%s\n", t2s<ADataType>::name);
// printf("\nblockIdx.x :%x, blockIdx.y :%x,\n", blockIdx.x, blockIdx.y);
// __builtin_amdgcn_sched_barrier(0);
// }
// __builtin_amdgcn_sched_barrier(0);
num_sorted_tiles
=
num_sorted_tiles
/
BlockShape
::
Block_M0
;
num_sorted_tiles
=
num_sorted_tiles
/
BlockShape
::
Block_M0
;
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
339a674b
...
@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -74,7 +74,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
}
}
// this is the thread-offset along row/col
// this is the thread-offset along row/col
...
@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -159,15 +159,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
template
<
typename
ROW_IDS
>
template
<
typename
ROW_IDS
>
CK_TILE_DEVICE
auto
GetAScale
(
const
ROW_IDS
row_ids_mma
,
CK_TILE_DEVICE
auto
GetAScale
(
const
ROW_IDS
row_ids_mma
,
const
AScaleDataType
*
a_scale_ptr
)
// const AScaleDataType* a_scale_ptr, index_t num_tokens_)
index_t
num_tokens_
)
{
{
constexpr
index_t
n_size
=
row_ids_mma
.
size
();
constexpr
index_t
n_size
=
row_ids_mma
.
size
();
array
<
TopkWeightDataType
,
n_size
>
w
;
array
<
TopkWeightDataType
,
n_size
>
w
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
auto
row_id
=
row_ids_mma
[
i
]
&
0xffffff
;
auto
row_id
=
row_ids_mma
[
i
]
&
0xffffff
;
auto
itp_k
=
row_ids_mma
[
i
]
>>
24
;
if
(
row_id
>=
num_tokens_
)
w
.
at
(
i
)
=
a_scale_ptr
[
row_id
*
5
+
itp_k
];
{
w
.
at
(
i
)
=
0.
f
;
}
else
{
w
.
at
(
i
)
=
1.
f
;
// auto itp_k = row_ids_mma[i] >> 24;
// w.at(i) = a_scale_ptr[row_id * 5+itp_k];
}
});
});
return
w
;
return
w
;
...
@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -247,7 +254,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Warp_K0
;
// divide K in W
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Warp_K0
;
// divide K in W
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Warp_N1
;
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Warp_N1
;
index_t
kr_1
=
shared_intermediate_size_1
/
BlockShape
::
Warp_K1
;
index_t
kr_1
=
shared_intermediate_size_1
/
BlockShape
::
Warp_K1
;
// if(threadIdx.x == 31 && blockIdx.x == 0 && blockIdx.y == 0)
// {
// printf("\nWarpPerBlock_N0 :%x, WarpPerBlock_M0:%x,\n", BlockShape::WarpPerBlock_N0
// , BlockShape::WarpPerBlock_M0);
// }
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
...
@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -271,7 +283,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
row_coords_a_mma
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
row_coords_a_mma
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
auto
token_id
=
generate_tuple
(
auto
token_id
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
return
(
row_ids_a
[
i
]
)
&
0xffffff
;
return
(
row_ids_a
[
i
]
&
0xffffff
)
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
auto
a_coords
=
generate_tuple
(
auto
a_coords
=
generate_tuple
(
...
@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -385,10 +397,16 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
2
/
kAlignmentO
)
*
kAlignmentO
;
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
2
/
kAlignmentO
)
*
kAlignmentO
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
auto
o_flags
=
auto
o_flags
=
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
token_id
[
i
],
kargs
.
num_tokens
);
},
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
token_id
[
i
],
kargs
.
num_tokens
);
},
number
<
row_ids_a
.
size
()
>
{});
// generate_tuple([&](auto i) {
// if (__builtin_amdgcn_readfirstlane(token_id[i]) < kargs.num_tokens)
// {return 0xffffffffffffffff;}
// else
// {return uint32x2_t 0;}
// },
number
<
token_id
.
size
()
>
{});
auto
o_res
=
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
...
@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -398,13 +416,59 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
w_scale
=
GetWeightScale
(
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
a_scale
=
GetAScale
(
auto
a_scale
=
GetAScale
(
row_ids_a_mma
,
reinterpret_cast
<
const
AScaleDataType
*>
(
kargs
.
a_scale_ptr
));
// row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr), kargs.num_tokens );
row_ids_a_mma
,
kargs
.
num_tokens
);
auto
gqsmq_coords
=
GetColCoords_GQSMQ
(
intermediate_tile_id
*
BlockShape
::
Block_K1
);
auto
gqsmq_coords
=
GetColCoords_GQSMQ
(
intermediate_tile_id
*
BlockShape
::
Block_K1
);
auto
dq_coords
=
gqsmq_coords
[
0
];
//only one for this tiling
auto
dq_coords
=
gqsmq_coords
[
0
];
//only one for this tiling
auto
gq_scale
=
GetGQScale
(
auto
gq_scale
=
GetGQScale
(
gqsmq_coords
,
(
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
g_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
gqsmq_coords
,
(
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
g_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
auto
smq_scale
=
GetSMQScale
(
auto
smq_scale
=
GetSMQScale
(
gqsmq_coords
,
(
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
gqsmq_coords
,
(
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
if
(
threadIdx
.
x
==
95
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
)
{
printf
(
"
\n
blockIdx.x :%x, blockIdx.y :%x, d ptr: %p, wg d ptr :%x%x,gemm0 done
\n
"
,
blockIdx
.
x
,
blockIdx
.
y
,
kargs
.
d_ptr
,
d_res
[
1
],
d_res
[
0
]);
// // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", hipThreadIdx_x , token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]);
// // printf("\n -----------thread id %x--- - token_id , 7:%x,, \n", hipThreadIdx_x , token_id[number<7>{}]);
// // printf("\n -------------- - exec 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", o_flags[number<0>{}][0],o_flags[number<1>{}][0],o_flags[number<2>{}][0],o_flags[number<3>{}][0], o_flags[number<4>{}][0],o_flags[number<5>{}][0],o_flags[number<6>{}][0],o_flags[number<7>{}][0]);
printf
(
"
\n
token id :%x,%x,%x,%x, %x,%x,%x,%x
\n
d_coords: %x,%x,%x,%x,
\n
row_idx: %x,%x,%x,%x, %x,%x,%x,%x
\n
o_flags:%x,%x,%x,%x, %x,%x,%x,%x
\n
"
,
token_id
[
number
<
0
>
{}],
token_id
[
number
<
1
>
{}],
token_id
[
number
<
2
>
{}],
token_id
[
number
<
3
>
{}],
token_id
[
number
<
4
>
{}],
token_id
[
number
<
5
>
{}],
token_id
[
number
<
6
>
{}],
token_id
[
number
<
7
>
{}],
d_coords
[
number
<
0
>
{}],
d_coords
[
number
<
1
>
{}],
d_coords
[
number
<
2
>
{}],
d_coords
[
number
<
3
>
{}],
// d_coords[number<4>{}],
// d_coords[number<5>{}],
// d_coords[number<6>{}],
// d_coords[number<7>{}],
row_ids_a
[
number
<
0
>
{}],
row_ids_a
[
number
<
1
>
{}],
row_ids_a
[
number
<
2
>
{}],
row_ids_a
[
number
<
3
>
{}],
row_ids_a
[
number
<
4
>
{}],
row_ids_a
[
number
<
5
>
{}],
row_ids_a
[
number
<
6
>
{}],
row_ids_a
[
number
<
7
>
{}],
o_flags
[
number
<
0
>
{}][
0
],
o_flags
[
number
<
1
>
{}][
0
],
o_flags
[
number
<
2
>
{}][
0
],
o_flags
[
number
<
3
>
{}][
0
],
o_flags
[
number
<
4
>
{}][
0
],
o_flags
[
number
<
5
>
{}][
0
],
o_flags
[
number
<
6
>
{}][
0
],
o_flags
[
number
<
7
>
{}][
0
]);
// return;
}
__builtin_amdgcn_sched_barrier
(
0
);
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
// auto acc_0= uk_0(
// auto acc_0= uk_0(
uk_0
(
a_scale
,
uk_0
(
a_scale
,
...
@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -418,17 +482,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem
,
smem
,
kargs
.
hidden_size
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
16
*
256
,
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
kargs
.
num_tokens
*
kargs
.
stride_token
);
// tile offset for B matrix each unroll
if
(
hipBlockIdx_x
==
1
&&
hipBlockIdx_y
==
1
&&
hipBlockIdx_z
==
0
&&
// return;
hipThreadIdx_x
==
64
)
__builtin_amdgcn_sched_barrier
(
0
);
{
printf
(
"
\n
gemm0 done
\n
"
);
// // sweep_tile(
// printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]);
// printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]);
printf
(
"
\n
-------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,,
\n
"
,
token_id
[
number
<
0
>
{}],
token_id
[
number
<
1
>
{}],
token_id
[
number
<
2
>
{}],
token_id
[
number
<
3
>
{}],
token_id
[
number
<
4
>
{}],
token_id
[
number
<
5
>
{}],
token_id
[
number
<
6
>
{}],
token_id
[
number
<
7
>
{}]);
}
// sweep_tile(
// acc_0,
// acc_0,
// [&](auto idx0, auto idx1) {
// [&](auto idx0, auto idx1) {
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
...
@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -449,6 +508,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
uk_1
(
uk_1
(
// dq_res,
// dq_res,
// d_res,
// d_res,
token_id
,
dq_coords
,
dq_coords
,
d_coords
,
d_coords
,
o_res
,
o_res
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment