Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6f7d1272
Commit
6f7d1272
authored
Jan 05, 2025
by
shengnxu
Browse files
changed all the scale outside except for uq
parent
9a46c0e7
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
671 additions
and
834 deletions
+671
-834
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+7
-17
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+20
-15
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
+102
-66
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+436
-466
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+45
-241
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+61
-29
No files found.
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
View file @
6f7d1272
...
@@ -245,13 +245,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -245,13 +245,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
// TODO: need paired with tile_window_linear!
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// TODO: need call init_raw() before call this function!
template
<
typename
A
Token_id
,
typename
AQRes
,
typename
DQRes
,
typename
GQRes
,
typename
SMQRes
,
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
template
<
typename
A
scale
,
typename
GQscale
,
typename
ARes
,
typename
ACoords
,
typename
BRes
,
typename
BCoords
>
CK_TILE_DEVICE
auto
CK_TILE_DEVICE
auto
operator
()(
const
AToken_id
&
row_ids_a_
,
operator
()(
const
Ascale
&
a_scale_
,
const
AQRes
&
res_aq
,
const
GQscale
&
gq_scale_
,
const
DQRes
&
res_dq
,
const
GQRes
&
res_gq
,
const
SMQRes
&
res_smq
,
const
ARes
&
res_a
,
const
ARes
&
res_a
,
const
ACoords
&
cached_coords_a
,
const
ACoords
&
cached_coords_a
,
const
BRes
&
res_b
,
const
BRes
&
res_b
,
...
@@ -263,7 +260,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -263,7 +260,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
{
{
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
4
/*2x per dword*/
);
// 8
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
4
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
static_assert
(
AToken_id
::
size
()
==
Repeat_M
);
static_assert
(
Ascale
::
size
()
==
Repeat_M
);
static_assert
(
Ascale
::
size
()
==
Repeat_M
);
auto
a_sst
=
make_tile_window
(
auto
a_sst
=
make_tile_window
(
...
@@ -372,10 +368,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -372,10 +368,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
register
int
v_z61
asm
(
"v189"
)
=
0
;
register
int
v_z61
asm
(
"v189"
)
=
0
;
register
int
v_z62
asm
(
"v190"
)
=
0
;
register
int
v_z62
asm
(
"v190"
)
=
0
;
register
int
v_z63
asm
(
"v191"
)
=
0
;
register
int
v_z63
asm
(
"v191"
)
=
0
;
index_t
temp0
=
static_cast
<
index_t
>
(
row_ids_a_
[
number
<
0
>
{}]);
index_t
temp1
=
static_cast
<
index_t
>
(
row_ids_a_
[
number
<
1
>
{}]);
// B nr->kr
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
#pragma clang diagnostic ignored "-Winline-asm"
...
@@ -449,13 +441,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
...
@@ -449,13 +441,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[
c61
]
"+v"
(
v_z61
),
[
c61
]
"+v"
(
v_z61
),
[
c62
]
"+v"
(
v_z62
),
[
c62
]
"+v"
(
v_z62
),
[
c63
]
"+v"
(
v_z63
),
[
c63
]
"+v"
(
v_z63
),
[
v_token_id0
]
"+v"
(
temp0
),
[
v_token_id1
]
"+v"
(
temp1
),
[
s_mem_
]
"+r"
(
smem
)
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_aq
]
"
s
"
(
res_aq
),
:
[
a_scale0
]
"
v
"
(
a_scale_
[
0
]
),
[
s_res_dq
]
"
s
"
(
res_dq
),
[
a_scale1
]
"
v
"
(
a_scale_
[
1
]
),
[
s_res_gq
]
"
s
"
(
res_gq
),
[
gq_scale0
]
"
v
"
(
gq_scale_
[
0
]
),
[
s_res_smq
]
"
s
"
(
res_smq
),
[
gq_scale1
]
"
v
"
(
gq_scale_
[
1
]
),
[
s_res_a
]
"s"
(
res_a
),
[
s_res_a
]
"s"
(
res_a
),
// [s_res_a1]"s"(res_a[1]),
// [s_res_a1]"s"(res_a[1]),
// [s_res_a2]"s"(res_a[2]),
// [s_res_a2]"s"(res_a[2]),
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
6f7d1272
...
@@ -80,21 +80,25 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -80,21 +80,25 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
DQRes
,
template
<
typename
DQRes
,
typename
BRes
,
typename
BRes
,
typename
DQCoords
,
typename
BCoords
,
typename
BCoords
,
typename
ORes
,
typename
ORes
,
typename
OCoords
,
typename
OCoords
,
typename
OFlags
>
typename
OFlags
,
// typename ScaleTensor>
typename
ScaleTensor
,
typename
YScaleTensor
>
CK_TILE_DEVICE
auto
CK_TILE_DEVICE
auto
operator
()(
const
DQRes
&
res_dq
,
operator
()(
const
DQRes
&
res_dq
,
const
BRes
&
res_b
,
const
BRes
&
res_b
,
const
DQCoords
&
cached_coords_dq
,
const
BCoords
&
cached_coords_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
index_t
n
,
// loop along n dim
// const ScaleTensor& scale_,
const
ScaleTensor
&
scale_
,
const
YScaleTensor
&
smq_scale_
,
index_t
tile_offset_dq
,
index_t
tile_offset_dq
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_half_b
,
//splited load alone K in to 2 part
index_t
tile_offset_half_b
,
//splited load alone K in to 2 part
...
@@ -108,9 +112,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -108,9 +112,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
const
index_t
tile_stride_dq_bytes
=
tile_offset_dq
*
sizeof
(
DScaleDataType
);
const
index_t
tile_stride_dq_bytes
=
tile_offset_dq
*
sizeof
(
DScaleDataType
);
//
static_assert(ScaleTensor::size() == 2);
static_assert
(
ScaleTensor
::
size
()
==
2
);
//
float s0 = scale_[number<0>{}];
float
s0
=
scale_
[
number
<
0
>
{}];
//
float s1 = scale_[number<1>{}];
float
s1
=
scale_
[
number
<
1
>
{}];
index_t
loop_cnt
=
n
;
index_t
loop_cnt
=
n
;
...
@@ -220,8 +224,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -220,8 +224,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sld_y_os]"v"(sld_y_os),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
// [v_sfl_sst]"v"(sfl_sst),
[
smq_scale0
]
"s"
(
smq_scale_
[
0
]),
[
smq_scale1
]
"s"
(
smq_scale_
[
1
]),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
//[s_res_o3]"s"(res_o[3]),
...
@@ -229,6 +235,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -229,6 +235,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [s_res_b1]"s"(res_b[1]),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
// [s_res_b3]"s"(res_b[3]),
[
v_os_dq
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_dq
*
sizeof
(
DScaleDataType
))),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
...
@@ -293,8 +300,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -293,8 +300,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
...
@@ -366,8 +372,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -366,8 +372,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
// [v_sfl_sst]"v"(sfl_sst),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_d
]
"s"
(
res_b
),
[
s_res_d
]
"s"
(
res_b
),
...
@@ -390,8 +396,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -390,8 +396,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
),
//
[scale_0]"v"(s0),
[
scale_0
]
"v"
(
s0
),
//
[scale_1]"v"(s1),
[
scale_1
]
"v"
(
s1
),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_lo]"v"(nan_lo),
// [v_nan_hi]"v"(nan_hi),
// [v_nan_hi]"v"(nan_hi),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
...
@@ -438,8 +444,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
...
@@ -438,8 +444,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
View file @
6f7d1272
...
@@ -27,8 +27,12 @@
...
@@ -27,8 +27,12 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
#endif
" v_and_b32 v0, 0x3f, v0
\n
"
" v_lshrrev_b32 v3, 6, v0
\n
"
" v_readfirstlane_b32 s7, v3
\n
"
" s_waitcnt vmcnt(24)
\n
"
" s_waitcnt vmcnt(24)
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v56, v130, v130
\n
"
" v_mul_f32 v56, v130, v130
\n
"
...
@@ -49,7 +53,6 @@
...
@@ -49,7 +53,6 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
@@ -577,71 +580,71 @@
...
@@ -577,71 +580,71 @@
" v_mul_f32 v189, v189, v55
\n
"
" v_mul_f32 v189, v189, v55
\n
"
" v_mul_f32 v190, v190, v56
\n
"
" v_mul_f32 v190, v190, v56
\n
"
" v_mul_f32 v191, v191, v57
\n
"
" v_mul_f32 v191, v191, v57
\n
"
" v_mul_f32 v128,
v18
, v128 row_newbcast:0
\n
"
" v_mul_f32 v128,
%[smq_scale0]
, v128 row_newbcast:0
\n
"
" v_mul_f32 v129,
v18
, v129 row_newbcast:1
\n
"
" v_mul_f32 v129,
%[smq_scale0]
, v129 row_newbcast:1
\n
"
" v_mul_f32 v130,
v18
, v130 row_newbcast:2
\n
"
" v_mul_f32 v130,
%[smq_scale0]
, v130 row_newbcast:2
\n
"
" v_mul_f32 v131,
v18
, v131 row_newbcast:3
\n
"
" v_mul_f32 v131,
%[smq_scale0]
, v131 row_newbcast:3
\n
"
" v_mul_f32 v132,
v18
, v132 row_newbcast:0
\n
"
" v_mul_f32 v132,
%[smq_scale0]
, v132 row_newbcast:0
\n
"
" v_mul_f32 v133,
v18
, v133 row_newbcast:1
\n
"
" v_mul_f32 v133,
%[smq_scale0]
, v133 row_newbcast:1
\n
"
" v_mul_f32 v134,
v18
, v134 row_newbcast:2
\n
"
" v_mul_f32 v134,
%[smq_scale0]
, v134 row_newbcast:2
\n
"
" v_mul_f32 v135,
v18
, v135 row_newbcast:3
\n
"
" v_mul_f32 v135,
%[smq_scale0]
, v135 row_newbcast:3
\n
"
" v_mul_f32 v136,
v18
, v136 row_newbcast:4
\n
"
" v_mul_f32 v136,
%[smq_scale0]
, v136 row_newbcast:4
\n
"
" v_mul_f32 v137,
v18
, v137 row_newbcast:5
\n
"
" v_mul_f32 v137,
%[smq_scale0]
, v137 row_newbcast:5
\n
"
" v_mul_f32 v138,
v18
, v138 row_newbcast:6
\n
"
" v_mul_f32 v138,
%[smq_scale0]
, v138 row_newbcast:6
\n
"
" v_mul_f32 v139,
v18
, v139 row_newbcast:7
\n
"
" v_mul_f32 v139,
%[smq_scale0]
, v139 row_newbcast:7
\n
"
" v_mul_f32 v140,
v18
, v140 row_newbcast:4
\n
"
" v_mul_f32 v140,
%[smq_scale0]
, v140 row_newbcast:4
\n
"
" v_mul_f32 v141,
v18
, v141 row_newbcast:5
\n
"
" v_mul_f32 v141,
%[smq_scale0]
, v141 row_newbcast:5
\n
"
" v_mul_f32 v142,
v18
, v142 row_newbcast:6
\n
"
" v_mul_f32 v142,
%[smq_scale0]
, v142 row_newbcast:6
\n
"
" v_mul_f32 v143,
v18
, v143 row_newbcast:7
\n
"
" v_mul_f32 v143,
%[smq_scale0]
, v143 row_newbcast:7
\n
"
" v_mul_f32 v144,
v18
, v144 row_newbcast:8
\n
"
" v_mul_f32 v144,
%[smq_scale0]
, v144 row_newbcast:8
\n
"
" v_mul_f32 v145,
v18
, v145 row_newbcast:9
\n
"
" v_mul_f32 v145,
%[smq_scale0]
, v145 row_newbcast:9
\n
"
" v_mul_f32 v146,
v18
, v146 row_newbcast:10
\n
"
" v_mul_f32 v146,
%[smq_scale0]
, v146 row_newbcast:10
\n
"
" v_mul_f32 v147,
v18
, v147 row_newbcast:11
\n
"
" v_mul_f32 v147,
%[smq_scale0]
, v147 row_newbcast:11
\n
"
" v_mul_f32 v148,
v18
, v148 row_newbcast:8
\n
"
" v_mul_f32 v148,
%[smq_scale0]
, v148 row_newbcast:8
\n
"
" v_mul_f32 v149,
v18
, v149 row_newbcast:9
\n
"
" v_mul_f32 v149,
%[smq_scale0]
, v149 row_newbcast:9
\n
"
" v_mul_f32 v150,
v18
, v150 row_newbcast:10
\n
"
" v_mul_f32 v150,
%[smq_scale0]
, v150 row_newbcast:10
\n
"
" v_mul_f32 v151,
v18
, v151 row_newbcast:11
\n
"
" v_mul_f32 v151,
%[smq_scale0]
, v151 row_newbcast:11
\n
"
" v_mul_f32 v152,
v18
, v152 row_newbcast:12
\n
"
" v_mul_f32 v152,
%[smq_scale0]
, v152 row_newbcast:12
\n
"
" v_mul_f32 v153,
v18
, v153 row_newbcast:13
\n
"
" v_mul_f32 v153,
%[smq_scale0]
, v153 row_newbcast:13
\n
"
" v_mul_f32 v154,
v18
, v154 row_newbcast:14
\n
"
" v_mul_f32 v154,
%[smq_scale0]
, v154 row_newbcast:14
\n
"
" v_mul_f32 v155,
v18
, v155 row_newbcast:15
\n
"
" v_mul_f32 v155,
%[smq_scale0]
, v155 row_newbcast:15
\n
"
" v_mul_f32 v156,
v18
, v156 row_newbcast:12
\n
"
" v_mul_f32 v156,
%[smq_scale0]
, v156 row_newbcast:12
\n
"
" v_mul_f32 v157,
v18
, v157 row_newbcast:13
\n
"
" v_mul_f32 v157,
%[smq_scale0]
, v157 row_newbcast:13
\n
"
" v_mul_f32 v158,
v18
, v158 row_newbcast:14
\n
"
" v_mul_f32 v158,
%[smq_scale0]
, v158 row_newbcast:14
\n
"
" v_mul_f32 v159,
v18
, v159 row_newbcast:15
\n
"
" v_mul_f32 v159,
%[smq_scale0]
, v159 row_newbcast:15
\n
"
" v_mul_f32 v160,
v19
, v160 row_newbcast:0
\n
"
" v_mul_f32 v160,
%[smq_scale1]
, v160 row_newbcast:0
\n
"
" v_mul_f32 v161,
v19
, v161 row_newbcast:1
\n
"
" v_mul_f32 v161,
%[smq_scale1]
, v161 row_newbcast:1
\n
"
" v_mul_f32 v162,
v19
, v162 row_newbcast:2
\n
"
" v_mul_f32 v162,
%[smq_scale1]
, v162 row_newbcast:2
\n
"
" v_mul_f32 v163,
v19
, v163 row_newbcast:3
\n
"
" v_mul_f32 v163,
%[smq_scale1]
, v163 row_newbcast:3
\n
"
" v_mul_f32 v164,
v19
, v164 row_newbcast:0
\n
"
" v_mul_f32 v164,
%[smq_scale1]
, v164 row_newbcast:0
\n
"
" v_mul_f32 v165,
v19
, v165 row_newbcast:1
\n
"
" v_mul_f32 v165,
%[smq_scale1]
, v165 row_newbcast:1
\n
"
" v_mul_f32 v166,
v19
, v166 row_newbcast:2
\n
"
" v_mul_f32 v166,
%[smq_scale1]
, v166 row_newbcast:2
\n
"
" v_mul_f32 v167,
v19
, v167 row_newbcast:3
\n
"
" v_mul_f32 v167,
%[smq_scale1]
, v167 row_newbcast:3
\n
"
" v_mul_f32 v168,
v19
, v168 row_newbcast:4
\n
"
" v_mul_f32 v168,
%[smq_scale1]
, v168 row_newbcast:4
\n
"
" v_mul_f32 v169,
v19
, v169 row_newbcast:5
\n
"
" v_mul_f32 v169,
%[smq_scale1]
, v169 row_newbcast:5
\n
"
" v_mul_f32 v170,
v19
, v170 row_newbcast:6
\n
"
" v_mul_f32 v170,
%[smq_scale1]
, v170 row_newbcast:6
\n
"
" v_mul_f32 v171,
v19
, v171 row_newbcast:7
\n
"
" v_mul_f32 v171,
%[smq_scale1]
, v171 row_newbcast:7
\n
"
" v_mul_f32 v172,
v19
, v172 row_newbcast:4
\n
"
" v_mul_f32 v172,
%[smq_scale1]
, v172 row_newbcast:4
\n
"
" v_mul_f32 v173,
v19
, v173 row_newbcast:5
\n
"
" v_mul_f32 v173,
%[smq_scale1]
, v173 row_newbcast:5
\n
"
" v_mul_f32 v174,
v19
, v174 row_newbcast:6
\n
"
" v_mul_f32 v174,
%[smq_scale1]
, v174 row_newbcast:6
\n
"
" v_mul_f32 v175,
v19
, v175 row_newbcast:7
\n
"
" v_mul_f32 v175,
%[smq_scale1]
, v175 row_newbcast:7
\n
"
" v_mul_f32 v176,
v19
, v176 row_newbcast:8
\n
"
" v_mul_f32 v176,
%[smq_scale1]
, v176 row_newbcast:8
\n
"
" v_mul_f32 v177,
v19
, v177 row_newbcast:9
\n
"
" v_mul_f32 v177,
%[smq_scale1]
, v177 row_newbcast:9
\n
"
" v_mul_f32 v178,
v19
, v178 row_newbcast:10
\n
"
" v_mul_f32 v178,
%[smq_scale1]
, v178 row_newbcast:10
\n
"
" v_mul_f32 v179,
v19
, v179 row_newbcast:11
\n
"
" v_mul_f32 v179,
%[smq_scale1]
, v179 row_newbcast:11
\n
"
" v_mul_f32 v180,
v19
, v180 row_newbcast:8
\n
"
" v_mul_f32 v180,
%[smq_scale1]
, v180 row_newbcast:8
\n
"
" v_mul_f32 v181,
v19
, v181 row_newbcast:9
\n
"
" v_mul_f32 v181,
%[smq_scale1]
, v181 row_newbcast:9
\n
"
" v_mul_f32 v182,
v19
, v182 row_newbcast:10
\n
"
" v_mul_f32 v182,
%[smq_scale1]
, v182 row_newbcast:10
\n
"
" v_mul_f32 v183,
v19
, v183 row_newbcast:11
\n
"
" v_mul_f32 v183,
%[smq_scale1]
, v183 row_newbcast:11
\n
"
" v_mul_f32 v184,
v19
, v184 row_newbcast:12
\n
"
" v_mul_f32 v184,
%[smq_scale1]
, v184 row_newbcast:12
\n
"
" v_mul_f32 v185,
v19
, v185 row_newbcast:13
\n
"
" v_mul_f32 v185,
%[smq_scale1]
, v185 row_newbcast:13
\n
"
" v_mul_f32 v186,
v19
, v186 row_newbcast:14
\n
"
" v_mul_f32 v186,
%[smq_scale1]
, v186 row_newbcast:14
\n
"
" v_mul_f32 v187,
v19
, v187 row_newbcast:15
\n
"
" v_mul_f32 v187,
%[smq_scale1]
, v187 row_newbcast:15
\n
"
" v_mul_f32 v188,
v19
, v188 row_newbcast:12
\n
"
" v_mul_f32 v188,
%[smq_scale1]
, v188 row_newbcast:12
\n
"
" v_mul_f32 v189,
v19
, v189 row_newbcast:13
\n
"
" v_mul_f32 v189,
%[smq_scale1]
, v189 row_newbcast:13
\n
"
" v_mul_f32 v190,
v19
, v190 row_newbcast:14
\n
"
" v_mul_f32 v190,
%[smq_scale1]
, v190 row_newbcast:14
\n
"
" v_mul_f32 v191,
v19
, v191 row_newbcast:15
\n
"
" v_mul_f32 v191,
%[smq_scale1]
, v191 row_newbcast:15
\n
"
" buffer_load_dword v12,
v5
, %[s_res_dq], 0 offen
\n
"
" buffer_load_dword v12,
%[v_os_dq]
, %[s_res_dq], 0 offen
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
...
@@ -934,9 +937,42 @@
...
@@ -934,9 +937,42 @@
" v_lshlrev_b32 v54, 1, v54
\n
"
" v_lshlrev_b32 v54, 1, v54
\n
"
" v_add_u32 v55, v54, v55
\n
"
" v_add_u32 v55, v54, v55
\n
"
" v_lshlrev_b32 v54, 2, v55
\n
"
" v_lshlrev_b32 v54, 2, v55
\n
"
" ds_read_b64 v[128:129], v54 offset:18688
\n
"
" ds_read_b64 v[130:131], v54 offset:18816
\n
"
" ds_read_b64 v[132:133], v54 offset:19712
\n
"
" ds_read_b64 v[134:135], v54 offset:19840
\n
"
" ds_read_b64 v[136:137], v54 offset:20736
\n
"
" ds_read_b64 v[138:139], v54 offset:20864
\n
"
" ds_read_b64 v[140:141], v54 offset:21760
\n
"
" ds_read_b64 v[142:143], v54 offset:21888
\n
"
" ds_read_b64 v[144:145], v54 offset:22784
\n
"
" ds_read_b64 v[146:147], v54 offset:22912
\n
"
" ds_read_b64 v[148:149], v54 offset:23808
\n
"
" ds_read_b64 v[150:151], v54 offset:23936
\n
"
" ds_read_b64 v[152:153], v54 offset:24832
\n
"
" ds_read_b64 v[154:155], v54 offset:24960
\n
"
" ds_read_b64 v[156:157], v54 offset:25856
\n
"
" ds_read_b64 v[158:159], v54 offset:25984
\n
"
" ds_read_b64 v[160:161], v54 offset:26880
\n
"
" ds_read_b64 v[162:163], v54 offset:27008
\n
"
" ds_read_b64 v[164:165], v54 offset:27904
\n
"
" ds_read_b64 v[166:167], v54 offset:28032
\n
"
" ds_read_b64 v[168:169], v54 offset:28928
\n
"
" ds_read_b64 v[170:171], v54 offset:29056
\n
"
" ds_read_b64 v[172:173], v54 offset:29952
\n
"
" ds_read_b64 v[174:175], v54 offset:30080
\n
"
" ds_read_b64 v[176:177], v54 offset:30976
\n
"
" ds_read_b64 v[178:179], v54 offset:31104
\n
"
" ds_read_b64 v[180:181], v54 offset:32000
\n
"
" ds_read_b64 v[182:183], v54 offset:32128
\n
"
" ds_read_b64 v[184:185], v54 offset:33024
\n
"
" ds_read_b64 v[186:187], v54 offset:33152
\n
"
" ds_read_b64 v[188:189], v54 offset:34048
\n
"
" ds_read_b64 v[190:191], v54 offset:34176
\n
"
#undef _UK_MFMA_
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
6f7d1272
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
6f7d1272
...
@@ -5,36 +5,20 @@
...
@@ -5,36 +5,20 @@
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
#endif
#endif
# define _DEQUAN_CVT_(a0,a1,a2,a3,
b, c
) \
# define _DEQUAN_CVT_(a0,a1,a2,a3,
xq, gq,brd0,brd1,brd2,brd3
) \
" v_cvt_f32_i32
a0, a0
\n
"
\
" v_cvt_f32_i32
"
a0
", "
a0
"
\n
"
\
" v_cvt_f32_i32
a1, a1
\n
"
\
" v_cvt_f32_i32
"
a1
", "
a1
"
\n
"
\
" v_cvt_f32_i32
a2, a2
\n
"
\
" v_cvt_f32_i32
"
a2
", "
a2
"
\n
"
\
" v_cvt_f32_i32 a3
,
a3
\n
"
\
" v_cvt_f32_i32
"
a3
", "
a3
"
\n
"
\
" v_mul_f32
a0, v15, a0
\n
"
\
" v_mul_f32
"
a0
", "
xq
", "
a0
"
\n
"
\
" v_mul_f32
a1, v15, a1
\n
"
\
" v_mul_f32
"
a1
", "
xq
", "
a1
"
\n
"
\
" v_mul_f32
a2, v15,
a2
\n
"
\
" v_mul_f32
"
a2
", "
xq
", "
a2
"
\n
"
\
" v_mul_f32
a3, v15,
a3
\n
"
\
" v_mul_f32
"
a3
", "
xq
", "
a3
"
\n
"
\
" v_mul_f32
a0, v17, a0
row_newbcast:
12
\n
"
\
" v_mul_f32
"
a0
", "
gq
", "
a0
"
row_newbcast:
"
brd0
"
\n
"
\
" v_mul_f32
a1, v17,
a1 row_newbcast:1
3
\n
"
\
" v_mul_f32
"
a1
", "
gq
", "
a1
"
row_newbcast:
"
brd
1
"
\n
"
\
" v_mul_f32
a2, v17, a2
row_newbcast:
14
\n
"
\
" v_mul_f32
"
a2
", "
gq
", "
a2
"
row_newbcast:
"
brd2
"
\n
"
\
" v_mul_f32
a3, v17, a3
row_newbcast:
15
\n
"
\
" v_mul_f32
"
a3
", "
gq
", "
a3
"
row_newbcast:
"
brd3
"
\n
"
";----------------------------------------------
\n
"
" v_lshrrev_b32 v54, 4, v0
\n
"
" v_lshlrev_b32 v55, 2, v54
\n
"
" v_and_b32 v54, 15, v0
\n
"
" v_lshrrev_b32 v56, 2, v54
\n
"
" v_lshlrev_b32 v56, 6, v56
\n
"
" v_add_u32 v55, v56, v55
\n
"
" v_and_b32 v54, 3, v0
\n
"
" v_add_u32 v55, v54, v55
\n
"
" v_lshlrev_b32 v10, 2, v55
\n
"
" v_add_u32 v11, 0x00000400, v10
\n
"
" s_mul_i32 s60, %[s_wave_id], 16
\n
"
" s_mul_i32 s60, s60, 4
\n
"
" v_add_u32 v10, s60, v10
\n
"
" v_add_u32 v11, s60, v11
\n
"
" v_mov_b32 v5, v10
\n
"
";----------------------------------------------
\n
"
";----------------------------------------------
\n
"
" s_mov_b32 s57, 0x00000100
\n
"
" s_mov_b32 s57, 0x00000100
\n
"
" s_mov_b32 s58, 0x00001000
\n
"
" s_mov_b32 s58, 0x00001000
\n
"
...
@@ -53,27 +37,22 @@
...
@@ -53,27 +37,22 @@
" v_mov_b32 v52, 0x7fff0000
\n
"
" v_mov_b32 v52, 0x7fff0000
\n
"
" v_mov_b32 v53, 0x00007fff
\n
"
" v_mov_b32 v53, 0x00007fff
\n
"
" s_waitcnt 0x0000
\n
"
" s_waitcnt 0x0000
\n
"
";----------------------------------------------
\n
"
" v_lshrrev_b32 v54, 24, %[v_token_id0]
\n
"
" v_mul_i32_i24 v54, s66, v54
\n
"
" v_and_b32 v55, 0x00ffffff, %[v_token_id0]
\n
"
" v_add_u32 %[v_token_id0], v54, v55
\n
"
" v_lshrrev_b32 v54, 24, %[v_token_id1]
\n
"
" v_mul_i32_i24 v54, s66, v54
\n
"
" v_and_b32 v55, 0x00ffffff, %[v_token_id1]
\n
"
" v_add_u32 %[v_token_id1], v54, v55
\n
"
" v_lshlrev_b32 %[v_token_id0], 2, %[v_token_id0]
\n
"
" v_lshlrev_b32 %[v_token_id1], 2, %[v_token_id1]
\n
"
" buffer_load_dword v14, %[v_token_id0], %[s_res_aq], 0 offen
\n
"
" buffer_load_dword v15, %[v_token_id1], %[s_res_aq], 0 offen
\n
"
" buffer_load_dword v16, v10, %[s_res_gq], 0 offen
\n
"
" buffer_load_dword v17, v11, %[s_res_gq], 0 offen
\n
"
" buffer_load_dword v18, v10, %[s_res_smq], 0 offen
\n
"
" buffer_load_dword v19, v11, %[s_res_smq], 0 offen
\n
"
" buffer_load_dword v20, v8, s[40:43], 0 offen
\n
"
" buffer_load_dword v21, v9, s[40:43], 0 offen
\n
"
" s_mov_b32 s80, 0
\n
"
" s_mov_b32 s80, 0
\n
"
" v_lshrrev_b32 v54, 4, v0
\n
"
" v_mul_i32_i24 v3, 34, v54
\n
"
" v_and_b32 v54, 15, v0
\n
"
" v_mul_i32_i24 v55, 2, v54
\n
"
" v_add_u32 v3, v55, v3
\n
"
" s_mul_i32 s60, s7, 0x00000088
\n
"
" v_add_u32 v3, s60, v3
\n
"
" v_lshlrev_b32 v3, 2, v3
\n
"
" v_lshrrev_b32 v54, 1, v0
\n
"
" v_mul_i32_i24 v4, 34, v54
\n
"
" v_and_b32 v55, 1, v0
\n
"
" v_add_u32 v4, v55, v4
\n
"
" s_mul_i32 s60, s7, 2
\n
"
" v_add_u32 v4, s60, v4
\n
"
" v_lshlrev_b32 v4, 2, v4
\n
"
";----------------------------------------------
\n
"
";----------------------------------------------
\n
"
"; -- prefetch A0
\n
"
"; -- prefetch A0
\n
"
"s_add_u32 m0, 0, %[s_m0_init]
\n
"
"s_add_u32 m0, 0, %[s_m0_init]
\n
"
...
@@ -570,198 +549,23 @@
...
@@ -570,198 +549,23 @@
" s_branch label_start
\n
"
" s_branch label_start
\n
"
" label_end :
\n
"
" label_end :
\n
"
";----------------------------------------------
\n
"
";----------------------------------------------
\n
"
" v_cvt_f32_i32 v128, v128
\n
"
_DEQUAN_CVT_
(
"%[c0]"
,
"%[c1]"
,
"%[c2]"
,
"%[c3]"
,
"%[a_scale0]"
,
" %[gq_scale0]"
,
"0"
,
"1"
,
"2"
,
"3"
)
" v_cvt_f32_i32 v129, v129
\n
"
_DEQUAN_CVT_
(
"%[c4]"
,
"%[c5]"
,
"%[c6]"
,
"%[c7]"
,
"%[a_scale1]"
,
" %[gq_scale0]"
,
"0"
,
"1"
,
"2"
,
"3"
)
" v_cvt_f32_i32 v130, v130
\n
"
_DEQUAN_CVT_
(
"%[c8]"
,
"%[c9]"
,
"%[c10]"
,
"%[c11]"
,
"%[a_scale0]"
,
" %[gq_scale0]"
,
"4"
,
"5"
,
"6"
,
"7"
)
" v_cvt_f32_i32 v131, v131
\n
"
_DEQUAN_CVT_
(
"%[c12]"
,
"%[c13]"
,
"%[c14]"
,
"%[c15]"
,
"%[a_scale1]"
,
" %[gq_scale0]"
,
"4"
,
"5"
,
"6"
,
"7"
)
" v_mul_f32 v128, v14, v128
\n
"
_DEQUAN_CVT_
(
"%[c16]"
,
"%[c17]"
,
"%[c18]"
,
"%[c19]"
,
"%[a_scale0]"
,
" %[gq_scale0]"
,
"8"
,
"9"
,
"10"
,
"11"
)
" v_mul_f32 v129, v14, v129
\n
"
_DEQUAN_CVT_
(
"%[c20]"
,
"%[c21]"
,
"%[c22]"
,
"%[c23]"
,
"%[a_scale1]"
,
" %[gq_scale0]"
,
"8"
,
"9"
,
"10"
,
"11"
)
" v_mul_f32 v130, v14, v130
\n
"
_DEQUAN_CVT_
(
"%[c24]"
,
"%[c25]"
,
"%[c26]"
,
"%[c27]"
,
"%[a_scale0]"
,
" %[gq_scale0]"
,
"12"
,
"13"
,
"14"
,
"15"
)
" v_mul_f32 v131, v14, v131
\n
"
_DEQUAN_CVT_
(
"%[c28]"
,
"%[c29]"
,
"%[c30]"
,
"%[c31]"
,
"%[a_scale1]"
,
" %[gq_scale0]"
,
"12"
,
"13"
,
"14"
,
"15"
)
" v_mul_f32 v128, v16, v128 row_newbcast:0
\n
"
_DEQUAN_CVT_
(
"%[c32]"
,
"%[c33]"
,
"%[c34]"
,
"%[c35]"
,
"%[a_scale0]"
,
" %[gq_scale1]"
,
"0"
,
"1"
,
"2"
,
"3"
)
" v_mul_f32 v129, v16, v129 row_newbcast:1
\n
"
_DEQUAN_CVT_
(
"%[c36]"
,
"%[c37]"
,
"%[c38]"
,
"%[c39]"
,
"%[a_scale1]"
,
" %[gq_scale1]"
,
"0"
,
"1"
,
"2"
,
"3"
)
" v_mul_f32 v130, v16, v130 row_newbcast:2
\n
"
_DEQUAN_CVT_
(
"%[c40]"
,
"%[c41]"
,
"%[c42]"
,
"%[c43]"
,
"%[a_scale0]"
,
" %[gq_scale1]"
,
"4"
,
"5"
,
"6"
,
"7"
)
" v_mul_f32 v131, v16, v131 row_newbcast:3
\n
"
_DEQUAN_CVT_
(
"%[c44]"
,
"%[c45]"
,
"%[c46]"
,
"%[c47]"
,
"%[a_scale1]"
,
" %[gq_scale1]"
,
"4"
,
"5"
,
"6"
,
"7"
)
" v_cvt_f32_i32 v132, v132
\n
"
_DEQUAN_CVT_
(
"%[c48]"
,
"%[c49]"
,
"%[c50]"
,
"%[c51]"
,
"%[a_scale0]"
,
" %[gq_scale1]"
,
"8"
,
"9"
,
"10"
,
"11"
)
" v_cvt_f32_i32 v133, v133
\n
"
_DEQUAN_CVT_
(
"%[c52]"
,
"%[c53]"
,
"%[c54]"
,
"%[c55]"
,
"%[a_scale1]"
,
" %[gq_scale1]"
,
"8"
,
"9"
,
"10"
,
"11"
)
" v_cvt_f32_i32 v134, v134
\n
"
_DEQUAN_CVT_
(
"%[c56]"
,
"%[c57]"
,
"%[c58]"
,
"%[c59]"
,
"%[a_scale0]"
,
" %[gq_scale1]"
,
"12"
,
"13"
,
"14"
,
"15"
)
" v_cvt_f32_i32 v135, v135
\n
"
_DEQUAN_CVT_
(
"%[c60]"
,
"%[c61]"
,
"%[c62]"
,
"%[c63]"
,
"%[a_scale1]"
,
" %[gq_scale1]"
,
"12"
,
"13"
,
"14"
,
"15"
)
" v_mul_f32 v132, v15, v132
\n
"
" v_mul_f32 v133, v15, v133
\n
"
" v_mul_f32 v134, v15, v134
\n
"
" v_mul_f32 v135, v15, v135
\n
"
" v_mul_f32 v132, v16, v132 row_newbcast:0
\n
"
" v_mul_f32 v133, v16, v133 row_newbcast:1
\n
"
" v_mul_f32 v134, v16, v134 row_newbcast:2
\n
"
" v_mul_f32 v135, v16, v135 row_newbcast:3
\n
"
" v_cvt_f32_i32 v136, v136
\n
"
" v_cvt_f32_i32 v137, v137
\n
"
" v_cvt_f32_i32 v138, v138
\n
"
" v_cvt_f32_i32 v139, v139
\n
"
" v_mul_f32 v136, v14, v136
\n
"
" v_mul_f32 v137, v14, v137
\n
"
" v_mul_f32 v138, v14, v138
\n
"
" v_mul_f32 v139, v14, v139
\n
"
" v_mul_f32 v136, v16, v136 row_newbcast:4
\n
"
" v_mul_f32 v137, v16, v137 row_newbcast:5
\n
"
" v_mul_f32 v138, v16, v138 row_newbcast:6
\n
"
" v_mul_f32 v139, v16, v139 row_newbcast:7
\n
"
" v_cvt_f32_i32 v140, v140
\n
"
" v_cvt_f32_i32 v141, v141
\n
"
" v_cvt_f32_i32 v142, v142
\n
"
" v_cvt_f32_i32 v143, v143
\n
"
" v_mul_f32 v140, v15, v140
\n
"
" v_mul_f32 v141, v15, v141
\n
"
" v_mul_f32 v142, v15, v142
\n
"
" v_mul_f32 v143, v15, v143
\n
"
" v_mul_f32 v140, v16, v140 row_newbcast:4
\n
"
" v_mul_f32 v141, v16, v141 row_newbcast:5
\n
"
" v_mul_f32 v142, v16, v142 row_newbcast:6
\n
"
" v_mul_f32 v143, v16, v143 row_newbcast:7
\n
"
" v_cvt_f32_i32 v144, v144
\n
"
" v_cvt_f32_i32 v145, v145
\n
"
" v_cvt_f32_i32 v146, v146
\n
"
" v_cvt_f32_i32 v147, v147
\n
"
" v_mul_f32 v144, v14, v144
\n
"
" v_mul_f32 v145, v14, v145
\n
"
" v_mul_f32 v146, v14, v146
\n
"
" v_mul_f32 v147, v14, v147
\n
"
" v_mul_f32 v144, v16, v144 row_newbcast:8
\n
"
" v_mul_f32 v145, v16, v145 row_newbcast:9
\n
"
" v_mul_f32 v146, v16, v146 row_newbcast:10
\n
"
" v_mul_f32 v147, v16, v147 row_newbcast:11
\n
"
" v_cvt_f32_i32 v148, v148
\n
"
" v_cvt_f32_i32 v149, v149
\n
"
" v_cvt_f32_i32 v150, v150
\n
"
" v_cvt_f32_i32 v151, v151
\n
"
" v_mul_f32 v148, v15, v148
\n
"
" v_mul_f32 v149, v15, v149
\n
"
" v_mul_f32 v150, v15, v150
\n
"
" v_mul_f32 v151, v15, v151
\n
"
" v_mul_f32 v148, v16, v148 row_newbcast:8
\n
"
" v_mul_f32 v149, v16, v149 row_newbcast:9
\n
"
" v_mul_f32 v150, v16, v150 row_newbcast:10
\n
"
" v_mul_f32 v151, v16, v151 row_newbcast:11
\n
"
" v_cvt_f32_i32 v152, v152
\n
"
" v_cvt_f32_i32 v153, v153
\n
"
" v_cvt_f32_i32 v154, v154
\n
"
" v_cvt_f32_i32 v155, v155
\n
"
" v_mul_f32 v152, v14, v152
\n
"
" v_mul_f32 v153, v14, v153
\n
"
" v_mul_f32 v154, v14, v154
\n
"
" v_mul_f32 v155, v14, v155
\n
"
" v_mul_f32 v152, v16, v152 row_newbcast:12
\n
"
" v_mul_f32 v153, v16, v153 row_newbcast:13
\n
"
" v_mul_f32 v154, v16, v154 row_newbcast:14
\n
"
" v_mul_f32 v155, v16, v155 row_newbcast:15
\n
"
" v_cvt_f32_i32 v156, v156
\n
"
" v_cvt_f32_i32 v157, v157
\n
"
" v_cvt_f32_i32 v158, v158
\n
"
" v_cvt_f32_i32 v159, v159
\n
"
" v_mul_f32 v156, v15, v156
\n
"
" v_mul_f32 v157, v15, v157
\n
"
" v_mul_f32 v158, v15, v158
\n
"
" v_mul_f32 v159, v15, v159
\n
"
" v_mul_f32 v156, v16, v156 row_newbcast:12
\n
"
" v_mul_f32 v157, v16, v157 row_newbcast:13
\n
"
" v_mul_f32 v158, v16, v158 row_newbcast:14
\n
"
" v_mul_f32 v159, v16, v159 row_newbcast:15
\n
"
" v_cvt_f32_i32 v160, v160
\n
"
" v_cvt_f32_i32 v161, v161
\n
"
" v_cvt_f32_i32 v162, v162
\n
"
" v_cvt_f32_i32 v163, v163
\n
"
" v_mul_f32 v160, v14, v160
\n
"
" v_mul_f32 v161, v14, v161
\n
"
" v_mul_f32 v162, v14, v162
\n
"
" v_mul_f32 v163, v14, v163
\n
"
" v_mul_f32 v160, v17, v160 row_newbcast:0
\n
"
" v_mul_f32 v161, v17, v161 row_newbcast:1
\n
"
" v_mul_f32 v162, v17, v162 row_newbcast:2
\n
"
" v_mul_f32 v163, v17, v163 row_newbcast:3
\n
"
" v_cvt_f32_i32 v164, v164
\n
"
" v_cvt_f32_i32 v165, v165
\n
"
" v_cvt_f32_i32 v166, v166
\n
"
" v_cvt_f32_i32 v167, v167
\n
"
" v_mul_f32 v164, v15, v164
\n
"
" v_mul_f32 v165, v15, v165
\n
"
" v_mul_f32 v166, v15, v166
\n
"
" v_mul_f32 v167, v15, v167
\n
"
" v_mul_f32 v164, v17, v164 row_newbcast:0
\n
"
" v_mul_f32 v165, v17, v165 row_newbcast:1
\n
"
" v_mul_f32 v166, v17, v166 row_newbcast:2
\n
"
" v_mul_f32 v167, v17, v167 row_newbcast:3
\n
"
" v_cvt_f32_i32 v168, v168
\n
"
" v_cvt_f32_i32 v169, v169
\n
"
" v_cvt_f32_i32 v170, v170
\n
"
" v_cvt_f32_i32 v171, v171
\n
"
" v_mul_f32 v168, v14, v168
\n
"
" v_mul_f32 v169, v14, v169
\n
"
" v_mul_f32 v170, v14, v170
\n
"
" v_mul_f32 v171, v14, v171
\n
"
" v_mul_f32 v168, v17, v168 row_newbcast:4
\n
"
" v_mul_f32 v169, v17, v169 row_newbcast:5
\n
"
" v_mul_f32 v170, v17, v170 row_newbcast:6
\n
"
" v_mul_f32 v171, v17, v171 row_newbcast:7
\n
"
" v_cvt_f32_i32 v172, v172
\n
"
" v_cvt_f32_i32 v173, v173
\n
"
" v_cvt_f32_i32 v174, v174
\n
"
" v_cvt_f32_i32 v175, v175
\n
"
" v_mul_f32 v172, v15, v172
\n
"
" v_mul_f32 v173, v15, v173
\n
"
" v_mul_f32 v174, v15, v174
\n
"
" v_mul_f32 v175, v15, v175
\n
"
" v_mul_f32 v172, v17, v172 row_newbcast:4
\n
"
" v_mul_f32 v173, v17, v173 row_newbcast:5
\n
"
" v_mul_f32 v174, v17, v174 row_newbcast:6
\n
"
" v_mul_f32 v175, v17, v175 row_newbcast:7
\n
"
" v_cvt_f32_i32 v176, v176
\n
"
" v_cvt_f32_i32 v177, v177
\n
"
" v_cvt_f32_i32 v178, v178
\n
"
" v_cvt_f32_i32 v179, v179
\n
"
" v_mul_f32 v176, v14, v176
\n
"
" v_mul_f32 v177, v14, v177
\n
"
" v_mul_f32 v178, v14, v178
\n
"
" v_mul_f32 v179, v14, v179
\n
"
" v_mul_f32 v176, v17, v176 row_newbcast:8
\n
"
" v_mul_f32 v177, v17, v177 row_newbcast:9
\n
"
" v_mul_f32 v178, v17, v178 row_newbcast:10
\n
"
" v_mul_f32 v179, v17, v179 row_newbcast:11
\n
"
" v_cvt_f32_i32 v180, v180
\n
"
" v_cvt_f32_i32 v181, v181
\n
"
" v_cvt_f32_i32 v182, v182
\n
"
" v_cvt_f32_i32 v183, v183
\n
"
" v_mul_f32 v180, v15, v180
\n
"
" v_mul_f32 v181, v15, v181
\n
"
" v_mul_f32 v182, v15, v182
\n
"
" v_mul_f32 v183, v15, v183
\n
"
" v_mul_f32 v180, v17, v180 row_newbcast:8
\n
"
" v_mul_f32 v181, v17, v181 row_newbcast:9
\n
"
" v_mul_f32 v182, v17, v182 row_newbcast:10
\n
"
" v_mul_f32 v183, v17, v183 row_newbcast:11
\n
"
" v_cvt_f32_i32 v184, v184
\n
"
" v_cvt_f32_i32 v185, v185
\n
"
" v_cvt_f32_i32 v186, v186
\n
"
" v_cvt_f32_i32 v187, v187
\n
"
" v_mul_f32 v184, v14, v184
\n
"
" v_mul_f32 v185, v14, v185
\n
"
" v_mul_f32 v186, v14, v186
\n
"
" v_mul_f32 v187, v14, v187
\n
"
" v_mul_f32 v184, v17, v184 row_newbcast:12
\n
"
" v_mul_f32 v185, v17, v185 row_newbcast:13
\n
"
" v_mul_f32 v186, v17, v186 row_newbcast:14
\n
"
" v_mul_f32 v187, v17, v187 row_newbcast:15
\n
"
" v_cvt_f32_i32 v188, v188
\n
"
" v_cvt_f32_i32 v189, v189
\n
"
" v_cvt_f32_i32 v190, v190
\n
"
" v_cvt_f32_i32 v191, v191
\n
"
" v_mul_f32 v188, v15, v188
\n
"
" v_mul_f32 v189, v15, v189
\n
"
" v_mul_f32 v190, v15, v190
\n
"
" v_mul_f32 v191, v15, v191
\n
"
" v_mul_f32 v188, v17, v188 row_newbcast:12
\n
"
" v_mul_f32 v189, v17, v189 row_newbcast:13
\n
"
" v_mul_f32 v190, v17, v190 row_newbcast:14
\n
"
" v_mul_f32 v191, v17, v191 row_newbcast:15
\n
"
#undef _UK_MFMA_
#undef _UK_MFMA_
#undef _DEQUAN_CVT_
#undef _DEQUAN_CVT_
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
6f7d1272
...
@@ -186,6 +186,50 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -186,6 +186,50 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return
coords
;
return
coords
;
}
}
// TODO: this row id is before shuffle atomic, need use acc distribution
//this calculation shared by G and SMQ
CK_TILE_DEVICE
auto
GetColCoords_GQSMQ
(
index_t
base_offset
)
{
constexpr
index_t
MLanes
=
BlockShape
::
Warp_M1
;
constexpr
index_t
Repeat_N
=
2
;
//different,this load is partitioned along N
// auto h_id = threadIdx.x / MLanes ;
// auto r_id = threadIdx.x & 0xffff;
// auto p_id = r_id/4;
// auto q_is = threadIdx.x & 0x3;
array
<
index_t
,
Repeat_N
>
coords
;
static_for
<
0
,
Repeat_N
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
(
threadIdx
.
x
/
MLanes
)
*
4
+
(
threadIdx
.
x
&
0xffff
)
/
4
*
64
+
q_id
+
i
*
256
;
});
return
coords
;
}
//this calculation shared by G and SMQ
CK_TILE_DEVICE
auto
GetGQScale
(
const
COL_IDS
coords
,
const
GScaleDataType
*
g_scale_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
GScaleDataType
,
n_size
>
g_scale_value
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
g_scale_value
.
at
(
i
)
=
g_scale_ptr
[
coords
[
i
]];
});
return
g_scale_value
;
}
CK_TILE_DEVICE
auto
GetSMQScale
(
const
COL_IDS
coords
,
const
YSmoothScaleDataType
*
y_scale_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
YSmoothScaleDataType
,
n_size
>
y_scale_value
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
y_scale_value
.
at
(
i
)
=
y_scale_ptr
[
coords
[
i
]];
});
return
y_scale_value
;
}
template
<
typename
Karg
>
template
<
typename
Karg
>
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
...
@@ -230,12 +274,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -230,12 +274,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return
(
row_ids_a
[
i
])
&
0xffffff
;
return
(
row_ids_a
[
i
])
&
0xffffff
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
// auto token_id_mma = generate_tuple(
// [&](auto i) {
// return (row_ids_a_mma[i]) &0xffffff;
// },
// number<row_ids_a_mma.size()>{});
//addr in fact
auto
a_coords
=
generate_tuple
(
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
return
((
row_ids_a
[
i
])
&
0xffffff
)
*
kargs
.
stride_token
+
return
((
row_ids_a
[
i
])
&
0xffffff
)
*
kargs
.
stride_token
+
...
@@ -306,7 +344,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -306,7 +344,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
smq_win
=
[
&
]()
{
auto
smq_win
=
[
&
]()
{
const
YSmoothScaleDataType
*
smq_ptr
=
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
const
YSmoothScaleDataType
*
smq_ptr
=
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
smq_scale_expert_stride_0
+
static_cast
<
long_index_t
>
(
expert_id
)
*
smq_scale_expert_stride_0
+
intermediate_tile_id
*
BlockShape
::
Block_
N0
;
intermediate_tile_id
*
BlockShape
::
Block_
K1
;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto
smq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
smq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
smq_ptr
,
smq_ptr
,
...
@@ -346,15 +384,15 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -346,15 +384,15 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
d_res
=
d_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
d_res
=
d_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
//////gq
//////gq
auto
dq_win
=
[
&
]()
{
auto
dq_win
=
[
&
]()
{
const
DScaleDataType
*
g
_ptr
=
reinterpret_cast
<
const
DScaleDataType
*>
(
kargs
.
d_scale_ptr
)
+
const
DScaleDataType
*
dq
_ptr
=
reinterpret_cast
<
const
DScaleDataType
*>
(
kargs
.
d_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
d_scale_expert_stride_1
;
static_cast
<
long_index_t
>
(
expert_id
)
*
d_scale_expert_stride_1
;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
auto
g
_view_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
auto
dq
_view_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
g
_ptr
,
dq
_ptr
,
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
kargs
.
hidden_size
),
number
<
1
>
{});
number
<
1
>
{});
return
g
_view_
;
return
dq
_view_
;
}();
}();
auto
dq_res
=
dq_win
.
get_buffer_view
().
cached_buf_res_
;
auto
dq_res
=
dq_win
.
get_buffer_view
().
cached_buf_res_
;
...
@@ -400,15 +438,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -400,15 +438,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
token_id
[
i
],
kargs
.
num_tokens
);
},
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
token_id
[
i
],
kargs
.
num_tokens
);
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
// auto bridge_sst_win = [&]() {
// constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
// constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
// return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
// reinterpret_cast<YDataType*>(smem), desc_),
// desc_.get_lengths(),
// {0, 0},
// dist_);
// }();
auto
o_res
=
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ODataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ODataType
));
...
@@ -417,16 +446,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -417,16 +446,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
w_scale
=
GetWeightScale
(
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
a_scale
=
GetAScale
(
auto
a_scale
=
GetAScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
a_scale_ptr
));
row_ids_a_mma
,
reinterpret_cast
<
const
AScaleDataType
*>
(
kargs
.
a_scale_ptr
));
auto
gqsmq_coords
=
GetColCoords_GQSMQ
(
intermediated_tile_id
*
BlockShape
::
Block_K1
);
auto
dq_coords
=
gqsmq_coords
[
0
];
//only one for this tiling
auto
gq_scale
=
GetGQScale
(
gqsmq_coords
,
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
g_scale_ptr
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
auto
smq_scale
=
GetSMQScale
(
gqsmq_coords
,
reinterpret_cast
<
const
YSmoothScaleDataType
*>
(
kargs
.
y_smooth_scale_ptr
+
static_cast
<
long_index_t
>
(
expert_id
)
*
shared_intermediate_size_0
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
// auto acc_0= uk_0(
// auto acc_0= uk_0(
uk_0
(
uk_0
(
a_scale
,
row_ids_a_mma
,
//fake token id, 2D index for X scale
gq_scale
,
a_scale
,
dq_res
,
gq_res
,
smq_res
,
a_res
,
a_res
,
a_coords
,
a_coords
,
g_res
,
g_res
,
...
@@ -457,6 +487,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -457,6 +487,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
dq_res
,
uk_1
(
dq_res
,
d_res
,
d_res
,
dq_coords
,
d_coords
,
d_coords
,
o_res
,
o_res
,
o_coords
,
o_coords
,
...
@@ -464,6 +495,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
...
@@ -464,6 +495,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
smem
,
smem
,
kargs
.
hidden_size
,
// total n number
kargs
.
hidden_size
,
// total n number
w_scale
,
w_scale
,
smq_scale
,
BlockShape
::
Block_N1
,
BlockShape
::
Block_N1
,
shared_intermediate_size_1
*
BlockShape
::
Block_N1
-
kr_1
*
BlockShape
::
Block_W1
,
// along N
shared_intermediate_size_1
*
BlockShape
::
Block_N1
-
kr_1
*
BlockShape
::
Block_W1
,
// along N
kr_1
*
BlockShape
::
Block_W1
,
kr_1
*
BlockShape
::
Block_W1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment