Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9a46c0e7
Commit
9a46c0e7
authored
Jan 04, 2025
by
shengnxu
Browse files
move a scale out inline
parent
26d84960
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
390 additions
and
420 deletions
+390
-420
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+1
-1
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+16
-33
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+52
-60
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
+5
-5
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
+34
-39
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
...uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
+107
-107
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+140
-166
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+34
-8
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
View file @
9a46c0e7
...
...
@@ -5,7 +5,7 @@
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include "fused_moegemm_api.cpp"
//
#include "fused_moegemm_api.cpp"
#include <iostream>
template
<
ck_tile
::
index_t
...
Is
>
...
...
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
View file @
9a46c0e7
...
...
@@ -264,6 +264,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
static_assert
(
ACoords
::
size
()
==
Block_M
*
Block_K
/
BlockSize
/
4
/*2x per dword*/
);
// 8
static_assert
(
BCoords
::
size
()
==
Repeat_N
);
static_assert
(
AToken_id
::
size
()
==
Repeat_M
);
static_assert
(
Ascale
::
size
()
==
Repeat_M
);
auto
a_sst
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -451,30 +452,18 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
[
v_token_id0
]
"+v"
(
temp0
),
[
v_token_id1
]
"+v"
(
temp1
),
[
s_mem_
]
"+r"
(
smem
)
:
[
s_res_aq0
]
"s"
(
res_aq
[
0
]),
[
s_res_aq1
]
"s"
(
res_aq
[
1
]),
[
s_res_aq2
]
"s"
(
res_aq
[
2
]),
[
s_res_aq3
]
"s"
(
res_aq
[
3
]),
[
s_res_dq0
]
"s"
(
res_dq
[
0
]),
[
s_res_dq1
]
"s"
(
res_dq
[
1
]),
[
s_res_dq2
]
"s"
(
res_dq
[
2
]),
[
s_res_dq3
]
"s"
(
res_dq
[
3
]),
[
s_res_gq0
]
"s"
(
res_gq
[
0
]),
[
s_res_gq1
]
"s"
(
res_gq
[
1
]),
[
s_res_gq2
]
"s"
(
res_gq
[
2
]),
[
s_res_gq3
]
"s"
(
res_gq
[
3
]),
[
s_res_smq0
]
"s"
(
res_smq
[
0
]),
[
s_res_smq1
]
"s"
(
res_smq
[
1
]),
[
s_res_smq2
]
"s"
(
res_smq
[
2
]),
[
s_res_smq3
]
"s"
(
res_smq
[
3
]),
[
s_res_a0
]
"s"
(
res_a
[
0
]),
[
s_res_a1
]
"s"
(
res_a
[
1
]),
[
s_res_a2
]
"s"
(
res_a
[
2
]),
[
s_res_a3
]
"s"
(
res_a
[
3
]),
[
s_res_b0
]
"s"
(
res_b
[
0
]),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
:
[
s_res_aq
]
"s"
(
res_aq
),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_gq
]
"s"
(
res_gq
),
[
s_res_smq
]
"s"
(
res_smq
),
[
s_res_a
]
"s"
(
res_a
),
// [s_res_a1]"s"(res_a[1]),
// [s_res_a2]"s"(res_a[2]),
// [s_res_a3]"s"(res_a[3]),
[
s_res_b
]
"s"
(
res_b
),
// [s_res_b1]"s"(res_b[1]),
// [s_res_b2]"s"(res_b[2]),
// [s_res_b3]"s"(res_b[3]),
[
v_os_a0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
0
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
1
>
{}]
*
sizeof
(
ADataType
))),
[
v_os_a2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_a
[
number
<
2
>
{}]
*
sizeof
(
ADataType
))),
...
...
@@ -539,21 +528,15 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s8"
,
"s9"
,
"s10"
,
"s11"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s32"
,
"s33"
,
"s34"
,
"s35"
,
"s36"
,
"s37"
,
"s38"
,
"s39"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s6"
,
"s7"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
,
"v32"
,
"v33"
,
"v34"
,
"v35"
,
"v36"
,
"v37"
,
"v38"
,
"v39"
,
"v40"
,
"v41"
,
"v42"
,
"v43"
,
"v44"
,
"v45"
,
"v46"
,
"v47"
,
"v48"
,
"v49"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v58"
,
"v59"
,
"v60"
,
"v61"
,
"v62"
,
"v63"
,
"v64"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
View file @
9a46c0e7
...
...
@@ -78,21 +78,23 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template
<
typename
BRes
,
template
<
typename
DQRes
,
typename
BRes
,
typename
BCoords
,
typename
ORes
,
typename
OCoords
,
typename
OFlags
,
typename
ScaleTensor
>
typename
OFlags
>
//
typename ScaleTensor>
CK_TILE_DEVICE
auto
operator
()(
const
BRes
&
res_b
,
operator
()(
const
DQRes
&
res_dq
,
const
BRes
&
res_b
,
const
BCoords
&
cached_coords_b
,
const
ORes
&
res_o
,
const
OCoords
&
cached_coords_o
,
const
OFlags
&
o_flags
,
// this should be in sgpr
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
n
,
// loop along n dim
const
ScaleTensor
&
scale_
,
//
const ScaleTensor& scale_,
index_t
tile_offset_dq
,
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_half_b
,
// split load along K into 2 parts
...
...
@@ -106,11 +108,11 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
const
index_t
tile_stride_o_bytes
=
tile_offset_o
*
sizeof
(
ODataType
);
const
index_t
tile_stride_dq_bytes
=
tile_offset_dq
*
sizeof
(
DScaleDataType
);
static_assert
(
ScaleTensor
::
size
()
==
2
);
float
s0
=
scale_
[
number
<
0
>
{}];
float
s1
=
scale_
[
number
<
1
>
{}];
//
static_assert(ScaleTensor::size() == 2);
//
float s0 = scale_[number<0>{}];
//
float s1 = scale_[number<1>{}];
index_t
loop_cnt
=
n
/
Block_N
;
index_t
loop_cnt
=
n
;
// register float v_c0 asm("v64");
// register float v_c1 asm("v65");
...
...
@@ -144,15 +146,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// register float v_c29 asm("v93");
// register float v_c30 asm("v94");
// register float v_c31 asm("v95");
int32_t
nan_hi
=
0x7fff0000
;
int32_t
nan_lo
=
0x00007fff
;
//
int32_t nan_hi = 0x7fff0000;
//
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int
lane_id
=
threadIdx
.
x
%
64
;
int
sld_y_os
=
(
lane_id
%
16
)
*
4
+
(
lane_id
/
16
)
*
128
;
sld_y_os
*=
2
;
//
int lane_id = threadIdx.x % 64;
//
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
//
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
...
...
@@ -161,15 +163,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int
sfl_sst
=
(
threadIdx
.
x
%
16
*
4
)
+
(
threadIdx
.
x
/
16
)
*
(
64
+
4
);
sfl_sst
*=
2
;
//
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
//
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int
sfl_sld
=
(
lane_id
%
2
)
*
2
+
(
lane_id
/
2
)
*
(
64
+
4
)
+
(
threadIdx
.
x
/
64
)
*
4
;
sfl_sld
*=
2
;
//
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
//
sfl_sld *= 2;
// B nr->kr
// clang-format off
...
...
@@ -214,18 +216,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_
b0
]
"s"
(
res_b
[
0
]
),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
s_res_
d
]
"s"
(
res_b
),
//
[s_res_b1]"s"(res_b[1]),
//
[s_res_b2]"s"(res_b[2]),
//
[s_res_b3]"s"(res_b[3]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
...
...
@@ -242,10 +245,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
//
[scale_0]"v"(s0),
//
[scale_1]"v"(s1),
//
[v_nan_lo]"v"(nan_lo),
//
[v_nan_hi]"v"(nan_hi),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
...
...
@@ -285,21 +288,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s8"
,
"s9"
,
"s10"
,
"s11"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s32"
,
"s33"
,
"s34"
,
"s35"
,
"s36"
,
"s37"
,
"s38"
,
"s39"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s6"
,
"s7"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
,
"v32"
,
"v33"
,
"v34"
,
"v35"
,
"v36"
,
"v37"
,
"v38"
,
"v39"
,
"v40"
,
"v41"
,
"v42"
,
"v43"
,
"v44"
,
"v45"
,
"v46"
,
"v47"
,
"v48"
,
"v49"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v58"
,
"v59"
,
"v60"
,
"v61"
,
"v62"
,
"v63"
,
"v64"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
...
...
@@ -364,18 +361,19 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
// [c30]"+v"(v_c30),
// [c31]"+v"(v_c31)
:
[
sld_a_base
]
"n"
(
0
),
[
shfl_base
]
"n"
(
0
),
[
v_sld_y_os
]
"v"
(
sld_y_os
),
[
v_sfl_sld
]
"v"
(
sfl_sld
),
[
v_sfl_sst
]
"v"
(
sfl_sst
),
// [shfl_base]"n"(0),
// [v_sld_y_os]"v"(sld_y_os),
// [v_sfl_sld]"v"(sfl_sld),
// [v_sfl_sst]"v"(sfl_sst),
[
s_res_dq
]
"s"
(
res_dq
),
[
s_res_o0
]
"s"
(
res_o
[
0
]),
[
s_res_o1
]
"s"
(
res_o
[
1
]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[
s_res_
b0
]
"s"
(
res_b
[
0
]
),
[
s_res_b1
]
"s"
(
res_b
[
1
]),
[
s_res_b2
]
"s"
(
res_b
[
2
]),
[
s_res_b3
]
"s"
(
res_b
[
3
]),
[
s_res_
d
]
"s"
(
res_b
),
//
[s_res_b1]"s"(res_b[1]),
//
[s_res_b2]"s"(res_b[2]),
//
[s_res_b3]"s"(res_b[3]),
[
v_os_o0
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
0
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
1
>
{}]
*
sizeof
(
ODataType
))),
[
v_os_o2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_o
[
number
<
2
>
{}]
*
sizeof
(
ODataType
))),
...
...
@@ -392,10 +390,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
[
s_tile_os_b_half
]
"s"
(
tile_offset_half_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_dq
]
"s"
(
tile_stride_dq_bytes
),
[
scale_0
]
"v"
(
s0
),
[
scale_1
]
"v"
(
s1
),
[
v_nan_lo
]
"v"
(
nan_lo
),
[
v_nan_hi
]
"v"
(
nan_hi
),
//
[scale_0]"v"(s0),
//
[scale_1]"v"(s1),
//
[v_nan_lo]"v"(nan_lo),
//
[v_nan_hi]"v"(nan_hi),
[
s_execflag_0
]
"s"
(
o_flags
[
number
<
0
>
{}]),
[
s_execflag_1
]
"s"
(
o_flags
[
number
<
1
>
{}]),
[
s_execflag_2
]
"s"
(
o_flags
[
number
<
2
>
{}]),
...
...
@@ -435,21 +433,15 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
"a236"
,
"a237"
,
"a238"
,
"a239"
,
"a240"
,
"a241"
,
"a242"
,
"a243"
,
"a244"
,
"a245"
,
"a246"
,
"a247"
,
"a248"
,
"a249"
,
"a250"
,
"a251"
,
"a252"
,
"a253"
,
"a254"
,
"a255"
,
"s6"
,
"s7"
,
"s8"
,
"s9"
,
"s10"
,
"s11"
,
"s12"
,
"s13"
,
"s14"
,
"s15"
,
"s16"
,
"s17"
,
"s18"
,
"s19"
,
"s20"
,
"s21"
,
"s22"
,
"s23"
,
"s24"
,
"s25"
,
"s26"
,
"s27"
,
"s28"
,
"s29"
,
"s30"
,
"s31"
,
"s32"
,
"s33"
,
"s34"
,
"s35"
,
"s36"
,
"s37"
,
"s38"
,
"s39"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s6"
,
"s7"
,
"s40"
,
"s41"
,
"s42"
,
"s43"
,
"s44"
,
"s45"
,
"s46"
,
"s47"
,
"s48"
,
"s49"
,
"s50"
,
"s51"
,
"s52"
,
"s53"
,
"s54"
,
"s55"
,
"s56"
,
"s57"
,
"s58"
,
"s59"
,
"s60"
,
"s61"
,
"s62"
,
"s63"
,
"s64"
,
"s65"
,
"s66"
,
"s67"
,
"s68"
,
"s69"
,
"s70"
,
"s71"
,
"s72"
,
"s73"
,
"s74"
,
"s75"
,
"s76"
,
"s77"
,
"s78"
,
"s79"
,
"s80"
,
// s86 as tmp
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
,
"v32"
,
"v33"
,
"v34"
,
"v35"
,
"v36"
,
"v37"
,
"v38"
,
"v39"
,
"v40"
,
"v41"
,
"v42"
,
"v43"
,
"v44"
,
"v45"
,
"v46"
,
"v47"
,
"v48"
,
"v49"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v58"
,
"v59"
,
"v60"
,
"v61"
,
"v62"
,
"v63"
,
"v64"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v50"
,
"v51"
,
"v52"
,
"v53"
,
"v54"
,
"v55"
,
"v56"
,
"v57"
,
"v64"
,
"v65"
,
"v66"
,
"v67"
,
"v68"
,
"v69"
,
"v70"
,
"v71"
,
"v72"
,
"v73"
,
"v74"
,
"v75"
,
"v76"
,
"v77"
,
"v78"
,
"v79"
,
"v80"
,
"v81"
,
"v82"
,
"v83"
,
"v84"
,
"v85"
,
"v86"
,
"v87"
,
"v88"
,
"v89"
,
"v90"
,
"v91"
,
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
View file @
9a46c0e7
...
...
@@ -160,7 +160,7 @@
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:23168
\n
"
" s_mov_b32 s80, 0
\n
"
" s_waitcnt vmcnt(24)
\n
"
"
label_0AA6:
\n
"
"
L_start%=:
\n
"
" s_waitcnt vmcnt(30) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0
\n
"
...
...
@@ -398,7 +398,7 @@ _UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]")
_UK_PK_CVT_
(
"%[c14]"
,
"%[c15]"
,
"%[c7]"
)
" s_addk_i32 s80, 0x0080
\n
"
" s_cmp_lt_i32 s80, %[s_loop_cnt]
\n
"
" s_cbranch_scc0
label_0EC1
\n
"
" s_cbranch_scc0
L_end%=
\n
"
" s_waitcnt vmcnt(30) & lgkmcnt(0)
\n
"
" s_barrier
\n
"
_UK_MFMA_
" [%[c16], %[c17], %[c18], %[c19]], acc[128:129], v[128:129], 0
\n
"
...
...
@@ -636,9 +636,9 @@ _UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]")
_UK_PK_CVT_
(
"%[c30]"
,
"%[c31]"
,
"%[c23]"
)
" s_addk_i32 s80, 0x0080
\n
"
" s_cmp_lt_i32 s80, %[s_loop_cnt]
\n
"
" s_cbranch_scc0
label_0EC1
\n
"
" s_branch
label_0AA6
\n
"
"
label_0EC1:
\n
"
" s_cbranch_scc0
L_end%=
\n
"
" s_branch
L_start%=
\n
"
"
L_end%=:
\n
"
" s_waitcnt lgkmcnt(0)
\n
"
" s_barrier
\n
"
" ds_read_b32 v10, %[v_sfl_sld] offset:16640
\n
"
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_1.inc
View file @
9a46c0e7
...
...
@@ -27,14 +27,8 @@
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
" s_mov_b32 s8, %[s_res_o0]
\n
"
" s_mov_b32 s9, %[s_res_o1]
\n
"
" s_mov_b32 s12, %[s_res_b0]
\n
"
" s_mov_b32 s13, %[s_res_b1]
\n
"
" s_mov_b32 s14, %[s_res_b2]
\n
"
" s_mov_b32 s15, %[s_res_b3]
\n
"
" s_waitcnt vmcnt(24)
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v128, v128
\n
"
" v_mul_f32 v55, v129, v129
\n
"
" v_mul_f32 v56, v130, v130
\n
"
...
...
@@ -55,7 +49,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -68,7 +62,7 @@
" v_mul_f32 v129, v129, v55
\n
"
" v_mul_f32 v130, v130, v56
\n
"
" v_mul_f32 v131, v131, v57
\n
"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v132, v132
\n
"
" v_mul_f32 v55, v133, v133
\n
"
" v_mul_f32 v56, v134, v134
\n
"
...
...
@@ -89,7 +83,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -102,7 +96,7 @@
" v_mul_f32 v133, v133, v55
\n
"
" v_mul_f32 v134, v134, v56
\n
"
" v_mul_f32 v135, v135, v57
\n
"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v136, v136
\n
"
" v_mul_f32 v55, v137, v137
\n
"
" v_mul_f32 v56, v138, v138
\n
"
...
...
@@ -123,7 +117,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -136,7 +130,7 @@
" v_mul_f32 v137, v137, v55
\n
"
" v_mul_f32 v138, v138, v56
\n
"
" v_mul_f32 v139, v139, v57
\n
"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v140, v140
\n
"
" v_mul_f32 v55, v141, v141
\n
"
" v_mul_f32 v56, v142, v142
\n
"
...
...
@@ -157,7 +151,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -171,7 +165,7 @@
" v_mul_f32 v142, v142, v56
\n
"
" v_mul_f32 v143, v143, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v144, v144
\n
"
" v_mul_f32 v55, v145, v145
\n
"
" v_mul_f32 v56, v146, v146
\n
"
...
...
@@ -192,7 +186,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -205,7 +199,7 @@
" v_mul_f32 v145, v145, v55
\n
"
" v_mul_f32 v146, v146, v56
\n
"
" v_mul_f32 v147, v147, v57
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v148, v148
\n
"
" v_mul_f32 v55, v149, v149
\n
"
" v_mul_f32 v56, v150, v150
\n
"
...
...
@@ -226,7 +220,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -239,7 +233,7 @@
" v_mul_f32 v149, v149, v55
\n
"
" v_mul_f32 v150, v150, v56
\n
"
" v_mul_f32 v151, v151, v57
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v152, v152
\n
"
" v_mul_f32 v55, v153, v153
\n
"
" v_mul_f32 v56, v154, v154
\n
"
...
...
@@ -260,7 +254,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -273,7 +267,7 @@
" v_mul_f32 v153, v153, v55
\n
"
" v_mul_f32 v154, v154, v56
\n
"
" v_mul_f32 v155, v155, v57
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v156, v156
\n
"
" v_mul_f32 v55, v157, v157
\n
"
" v_mul_f32 v56, v158, v158
\n
"
...
...
@@ -294,7 +288,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3],
%[s_res_d
], 0 offen offset:3072
\n
"
" s_add_u32 s12, %[s_tile_os_b_half], s12
\n
"
" s_addc_u32 s13, 0, s13
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
...
...
@@ -310,7 +304,7 @@
" v_mul_f32 v158, v158, v56
\n
"
" v_mul_f32 v159, v159, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[64:67], %[v_os_b0],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v160, v160
\n
"
" v_mul_f32 v55, v161, v161
\n
"
" v_mul_f32 v56, v162, v162
\n
"
...
...
@@ -331,7 +325,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[68:71], %[v_os_b0],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -344,7 +338,7 @@
" v_mul_f32 v161, v161, v55
\n
"
" v_mul_f32 v162, v162, v56
\n
"
" v_mul_f32 v163, v163, v57
\n
"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[72:75], %[v_os_b0],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v164, v164
\n
"
" v_mul_f32 v55, v165, v165
\n
"
" v_mul_f32 v56, v166, v166
\n
"
...
...
@@ -365,7 +359,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[76:79], %[v_os_b0],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -378,7 +372,7 @@
" v_mul_f32 v165, v165, v55
\n
"
" v_mul_f32 v166, v166, v56
\n
"
" v_mul_f32 v167, v167, v57
\n
"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[80:83], %[v_os_b1],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v168, v168
\n
"
" v_mul_f32 v55, v169, v169
\n
"
" v_mul_f32 v56, v170, v170
\n
"
...
...
@@ -399,7 +393,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[84:87], %[v_os_b1],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -412,7 +406,7 @@
" v_mul_f32 v169, v169, v55
\n
"
" v_mul_f32 v170, v170, v56
\n
"
" v_mul_f32 v171, v171, v57
\n
"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[88:91], %[v_os_b1],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v172, v172
\n
"
" v_mul_f32 v55, v173, v173
\n
"
" v_mul_f32 v56, v174, v174
\n
"
...
...
@@ -433,7 +427,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[92:95], %[v_os_b1],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -447,7 +441,7 @@
" v_mul_f32 v174, v174, v56
\n
"
" v_mul_f32 v175, v175, v57
\n
"
" s_waitcnt vmcnt(24)
\n
"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[96:99], %[v_os_b2],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v176, v176
\n
"
" v_mul_f32 v55, v177, v177
\n
"
" v_mul_f32 v56, v178, v178
\n
"
...
...
@@ -468,7 +462,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[100:103], %[v_os_b2],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -481,7 +475,7 @@
" v_mul_f32 v177, v177, v55
\n
"
" v_mul_f32 v178, v178, v56
\n
"
" v_mul_f32 v179, v179, v57
\n
"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[104:107], %[v_os_b2],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v180, v180
\n
"
" v_mul_f32 v55, v181, v181
\n
"
" v_mul_f32 v56, v182, v182
\n
"
...
...
@@ -502,7 +496,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[108:111], %[v_os_b2],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -515,7 +509,7 @@
" v_mul_f32 v181, v181, v55
\n
"
" v_mul_f32 v182, v182, v56
\n
"
" v_mul_f32 v183, v183, v57
\n
"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3],
s[12:15
], 0 offen
\n
"
" buffer_load_dwordx4 acc[112:115], %[v_os_b3],
%[s_res_d
], 0 offen
\n
"
" v_mul_f32 v54, v184, v184
\n
"
" v_mul_f32 v55, v185, v185
\n
"
" v_mul_f32 v56, v186, v186
\n
"
...
...
@@ -536,7 +530,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3],
s[12:15
], 0 offen offset:1024
\n
"
" buffer_load_dwordx4 acc[116:119], %[v_os_b3],
%[s_res_d
], 0 offen offset:1024
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -549,7 +543,7 @@
" v_mul_f32 v185, v185, v55
\n
"
" v_mul_f32 v186, v186, v56
\n
"
" v_mul_f32 v187, v187, v57
\n
"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3],
s[12:15
], 0 offen offset:2048
\n
"
" buffer_load_dwordx4 acc[120:123], %[v_os_b3],
%[s_res_d
], 0 offen offset:2048
\n
"
" v_mul_f32 v54, v188, v188
\n
"
" v_mul_f32 v55, v189, v189
\n
"
" v_mul_f32 v56, v190, v190
\n
"
...
...
@@ -570,7 +564,7 @@
" v_exp_f32 v55, v55
\n
"
" v_exp_f32 v56, v56
\n
"
" v_exp_f32 v57, v57
\n
"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3],
s[12:15
], 0 offen offset:3072
\n
"
" buffer_load_dwordx4 acc[124:127], %[v_os_b3],
%[s_res_d
], 0 offen offset:3072
\n
"
" v_add_f32 v54, v54, 1.0
\n
"
" v_add_f32 v55, v55, 1.0
\n
"
" v_add_f32 v56, v56, 1.0
\n
"
...
...
@@ -647,7 +641,7 @@
" v_mul_f32 v189, v19, v189 row_newbcast:13
\n
"
" v_mul_f32 v190, v19, v190 row_newbcast:14
\n
"
" v_mul_f32 v191, v19, v191 row_newbcast:15
\n
"
" buffer_load_dword v12, v5,
s[16:19
], 0 offen
\n
"
" buffer_load_dword v12, v5,
%[s_res_dq
], 0 offen
\n
"
" v_mov_b32 v22, 0x358637bd
\n
"
" v_mov_b32 v23, 0x358637bd
\n
"
" v_max3_f32 v22, abs(v128), abs(v129), v22
\n
"
...
...
@@ -945,3 +939,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc
View file @
9a46c0e7
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
View file @
9a46c0e7
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
View file @
9a46c0e7
...
...
@@ -102,7 +102,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return
MRepeat
;
}
// TODO: properly support scatter/gather
// TODO: properly support scatter/gather
for load only
CK_TILE_DEVICE
auto
GetRowCoords_A
(
index_t
base_offset
)
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
...
...
@@ -116,6 +116,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return
coords
;
}
//for mma and A scale
CK_TILE_DEVICE
auto
GetRowCoords_A_mma
(
index_t
base_offset
)
{
// constexpr index_t KLans = 2;
...
...
@@ -156,6 +157,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
return
w
;
}
// Gathers one A-scale value per entry of `row_ids_mma`.
//
// Each entry of `row_ids_mma` is decoded as a packed id: the low 24 bits
// (`& 0xffffff`) are used as the row id and the bits above them (`>> 24`) as
// the top-k slot. The scale for that entry is then read from
// `a_scale_ptr[row_id * kargs.topk + itp_k]`.
//
// Params:
//   row_ids_mma - indexable container of packed ids; its size() must be usable
//                 in a constant expression, since it sizes the result array.
//   a_scale_ptr - base pointer of the A-scale buffer that is indexed above.
//
// Returns: an array with `row_ids_mma.size()` elements, one scale per id.
//
// NOTE(review): `kargs` is not a parameter of this function and is not declared
// in it; presumably `topk` has to be passed in or `kargs` made visible here --
// confirm against the enclosing struct before relying on this helper.
// NOTE(review): the result element type is TopkWeightDataType although the
// values are read through an AScaleDataType pointer; this looks carried over
// from the weight-gathering helper -- confirm the intended element type.
template <typename ROW_IDS>
CK_TILE_DEVICE auto GetAScale(const ROW_IDS row_ids_mma, const AScaleDataType* a_scale_ptr)
{
    constexpr index_t n_size = row_ids_mma.size();

    array<TopkWeightDataType, n_size> w;
    static_for<0, n_size, 1>{}([&](auto i) {
        // Unpack the id: low 24 bits select the row, the remaining high bits the top-k slot.
        auto row_id = row_ids_mma[i] & 0xffffff;
        auto itp_k  = row_ids_mma[i] >> 24;
        w.at(i)     = a_scale_ptr[row_id * kargs.topk + itp_k];
    });
    return w;
}
// TODO: this row id is before shuffle atomic, need use acc distribution
CK_TILE_DEVICE
auto
GetRowCoords_O
(
index_t
base_offset
)
{
...
...
@@ -203,7 +220,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
BlockShape
::
Block_Kr1
);
// intermediate_tile_id * Block_N / (N in W)
auto
row_coords_a
=
GetRowCoords_A
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_coords_a_mma
=
GetRowCoords_
A_mma
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_coords_a_mma
=
GetRowCoords_
O
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_ids_a
=
GetRowID
(
row_coords_a
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
auto
row_ids_a_mma
=
GetRowID
(
...
...
@@ -221,7 +238,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//addr in fact
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
(
row_ids_a
[
i
])
*
kargs
.
stride_token
+
return
(
(
row_ids_a
[
i
])
&
0xffffff
)
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
number
<
row_ids_a
.
size
()
>
{});
...
...
@@ -231,9 +248,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
//////aq
auto
aq_win
=
[
&
]()
{
const
AScaleDataType
*
aq_ptr
=
reinterpret_cast
<
const
AScaleDataType
*>
(
kargs
.
a_scale_ptr
);
auto
aq_view_
=
make_naive_tensor_view
_packed
<
address_space_enum
::
global
>
(
auto
aq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
aq_ptr
,
make_tuple
(
kargs
.
num_tokens
*
kargs
.
topk
),
make_tuple
(
1
),
number
<
1
>
{},
number
<
1
>
{});
return
aq_view_
;
...
...
@@ -272,9 +291,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast
<
long_index_t
>
(
expert_id
)
*
g_scale_expert_stride_0
+
intermediate_tile_id
*
BlockShape
::
Block_N0
;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto
gq_view_
=
make_naive_tensor_view
_packed
<
address_space_enum
::
global
>
(
auto
gq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
gq_ptr
,
make_tuple
(
shared_intermediate_size_1
),
make_tuple
(
1
),
number
<
1
>
{},
number
<
1
>
{});
return
gq_view_
;
...
...
@@ -287,9 +308,11 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
static_cast
<
long_index_t
>
(
expert_id
)
*
smq_scale_expert_stride_0
+
intermediate_tile_id
*
BlockShape
::
Block_N0
;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto
smq_view_
=
make_naive_tensor_view
_packed
<
address_space_enum
::
global
>
(
auto
smq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
smq_ptr
,
make_tuple
(
shared_intermediate_size_1
),
make_tuple
(
1
),
number
<
1
>
{},
number
<
1
>
{});
return
smq_view_
;
...
...
@@ -393,12 +416,14 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
auto
row_coords_o
=
GetRowCoords_O
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
a_scale
=
GetAScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
a_scale_ptr
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
// auto acc_0= uk_0(
uk_0
(
row_ids_a_mma
,
//fake token id, 2D index for X scale
a
q_res
,
a
_scale
,
dq_res
,
gq_res
,
smq_res
,
...
...
@@ -430,7 +455,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
// block_sync_lds();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
d_res
,
uk_1
(
dq_res
,
d_res
,
d_coords
,
o_res
,
o_coords
,
...
...
script/cmake-ck-dev.sh
View file @
9a46c0e7
...
...
@@ -17,7 +17,7 @@ fi
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment