gaoqiong / composable_kernel_ROCM

Commit f549173b (parent 811b75d3), authored Jan 01, 2025 by shengnxu:

    simple gemm2 for gemm1 debugging
Showing 5 changed files, with 2002 additions and 603 deletions:
- include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp (+6, -1)
- include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp (+3, -0)
- include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8.inc (+1948, -578)
- include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc (+20, -17)
- include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp (+25, -7)
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp

```diff
@@ -245,12 +245,13 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
     // TODO: need paired with tile_window_linear!
     // TODO: need call init_raw() before call this function!
-    template <typename DQRes, typename GQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
+    template <typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
     CK_TILE_DEVICE auto operator()(index_t row_ids_a_,
-                                   const DQRes& res_aq,
+                                   const DQRes& res_dq,
                                    const GQRes& res_gq,
+                                   const SMQRes& res_smq,
                                    const ARes& res_a,
                                    const ACoords& cached_coords_a,
                                    const BRes& res_b,
```
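For orientation, a hedged sketch of the new argument order follows. It is a standalone toy, not the ck_tile functor: the resource types are stand-ins, and only the parameter ordering visible in the hunk is taken from the source (the real operator() continues past res_b).

```cpp
// Standalone toy (not the ck_tile functor): shows the parameter ordering after
// this commit, with the smooth-quant scale resource res_smq inserted between
// the gate-scale and A-matrix resources. Resource types are stand-in arrays.
#include <array>
#include <cstdio>

using index_t  = int;
using resource = std::array<int, 4>; // stand-in for a 4-dword buffer resource

struct ToyFlatmm
{
    template <typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename BRes>
    void operator()(index_t row_ids_a_,
                    const DQRes& res_dq,
                    const GQRes& res_gq,
                    const SMQRes& res_smq, // new in this commit
                    const ARes& res_a,
                    const BRes& res_b) const
    {
        std::printf("row=%d dq=%d gq=%d smq=%d a=%d b=%d\n",
                    row_ids_a_, res_dq[0], res_gq[0], res_smq[0], res_a[0], res_b[0]);
    }
};

int main()
{
    resource dq{{1}}, gq{{2}}, smq{{3}}, a{{4}}, b{{5}};
    ToyFlatmm{}(7, dq, gq, smq, a, b);
    return 0;
}
```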
```diff
@@ -405,6 +406,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
           [s_res_gq1] "s"(res_gq[1]),
           [s_res_gq2] "s"(res_gq[2]),
           [s_res_gq3] "s"(res_gq[3]),
+          [s_res_smq0] "s"(res_smq[0]),
+          [s_res_smq1] "s"(res_smq[1]),
+          [s_res_smq2] "s"(res_smq[2]),
+          [s_res_smq3] "s"(res_smq[3]),
           [s_res_a0] "s"(res_a[0]),
           [s_res_a1] "s"(res_a[1]),
           [s_res_a2] "s"(res_a[2]),
```
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp

```diff
@@ -92,6 +92,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
                                    CK_TILE_LDS_ADDR void* smem,
                                    index_t n, // loop along n dim
                                    const ScaleTensor& scale_,
+                                   index_t tile_offset_dq,
                                    index_t tile_offset_b,      // stride b is fixed to blockKr * blockW, but still can adjust
                                    index_t tile_offset_half_b, // split load along K into 2 parts
                                    index_t tile_offset_o)
```
```diff
@@ -102,6 +103,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
         const index_t tile_stride_b_bytes      = tile_offset_b * sizeof(BDataType);
         const index_t tile_offset_half_b_bytes = tile_offset_half_b * sizeof(BDataType);
         const index_t tile_stride_o_bytes      = tile_offset_o * sizeof(ODataType);
+        const index_t tile_stride_dq_bytes     = tile_offset_dq * sizeof(DScaleDataType);
         static_assert(ScaleTensor::size() == 2);
         float s0 = scale_[number<0>{}];
```
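The added tile_stride_dq_bytes follows the existing pattern: per-unroll offsets are carried in elements and converted to byte strides once, before being handed to the asm micro-kernel. A standalone sketch of that conversion, with assumed element types and illustrative tile sizes:

```cpp
// Standalone sketch of the element-offset -> byte-stride conversion shown above.
// Element types and tile sizes are assumptions for illustration only.
#include <cstdint>
#include <cstdio>

int main()
{
    using index_t        = std::int32_t;
    using BDataType      = std::int8_t; // int8 weights
    using DScaleDataType = float;       // assumed dequant-scale element type

    const index_t tile_offset_b  = 256 * 64; // elements advanced per B tile (illustrative)
    const index_t tile_offset_dq = 256;      // elements advanced per dq-scale tile (illustrative)

    const index_t tile_stride_b_bytes  = static_cast<index_t>(tile_offset_b * sizeof(BDataType));
    const index_t tile_stride_dq_bytes = static_cast<index_t>(tile_offset_dq * sizeof(DScaleDataType));

    std::printf("B tile: %d bytes, dq tile: %d bytes\n", tile_stride_b_bytes, tile_stride_dq_bytes);
    return 0;
}
```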
```diff
@@ -244,6 +246,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
           [s_tile_os_o] "s"(tile_stride_o_bytes),
           [s_tile_os_b_half] "s"(tile_offset_half_b_bytes),
           [s_tile_os_b] "s"(tile_stride_b_bytes),
+          [s_tile_os_dq] "s"(tile_stride_dq_bytes),
           [scale_0] "v"(s0),
           [scale_1] "v"(s1),
           [v_nan_lo] "v"(nan_lo),
```
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8.inc

This diff is collapsed on the original page and not shown; it carries the bulk of the change (+1948, -578).
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc

```diff
@@ -12,7 +12,26 @@
     " v_mul_f32 a[2], v17, a[2] row_newbcast:14 \n" \
     " v_mul_f32 a[3], v17, a[3] row_newbcast:15 \n" \
+    "s_mov_b32 s16, %[s_res_dq0] \n"
+    "s_mov_b32 s17, %[s_res_dq1] \n"
+    "s_mov_b32 s18, %[s_res_dq2] \n"
+    "s_mov_b32 s19, %[s_res_dq3] \n"
+    "s_mov_b32 s32, %[s_res_gq0] \n"
+    "s_mov_b32 s33, %[s_res_gq1] \n"
+    "s_mov_b32 s34, %[s_res_gq2] \n"
+    "s_mov_b32 s35, %[s_res_gq3] \n"
+    "s_mov_b32 s36, %[s_res_smq0] \n"
+    "s_mov_b32 s37, %[s_res_smq1] \n"
+    "s_mov_b32 s38, %[s_res_smq2] \n"
+    "s_mov_b32 s39, %[s_res_smq3] \n"
+    "s_mov_b32 s20, %[s_res_a0] \n"
+    "s_mov_b32 s21, %[s_res_a1] \n"
+    "s_mov_b32 s22, %[s_res_a2] \n"
+    "s_mov_b32 s23, %[s_res_a3] \n"
+    "s_mov_b32 s24, %[s_res_b0] \n"
+    "s_mov_b32 s25, %[s_res_b1] \n"
+    "s_mov_b32 s26, %[s_res_b2] \n"
+    "s_mov_b32 s27, %[s_res_b3] \n"
     //////////GQ/DQ/GsmQ_addr///////////////
     //expert weight addr no need
```
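The added s_mov_b32 block parks each tensor's 128-bit buffer resource (four 32-bit dwords, bound as scalar "s" operands) in a fixed SGPR range at the top of the micro-kernel, with the new smooth-quant descriptors in s36..s39 alongside dq/gq/a/b. A hedged, minimal sketch of binding one such resource from HIP C++ (the wrapper, the stand-in vector type, and the explicit clobber list are assumptions; the real .inc binds all resources in a single asm block and manages registers itself):

```cpp
// Illustrative only: copy the four dwords of one buffer resource into fixed
// SGPRs (s36..s39 for the smooth-quant scales, per the hunk above). The
// wrapper, the stand-in vector type, and the clobber list are assumptions.
typedef int int32x4_t __attribute__((ext_vector_type(4))); // stand-in descriptor type

__device__ void bind_smq_resource(int32x4_t res_smq)
{
    asm volatile("s_mov_b32 s36, %[s_res_smq0] \n"
                 "s_mov_b32 s37, %[s_res_smq1] \n"
                 "s_mov_b32 s38, %[s_res_smq2] \n"
                 "s_mov_b32 s39, %[s_res_smq3] \n"
                 : // no outputs
                 : [s_res_smq0] "s"(res_smq[0]),
                   [s_res_smq1] "s"(res_smq[1]),
                   [s_res_smq2] "s"(res_smq[2]),
                   [s_res_smq3] "s"(res_smq[3])
                 : "s36", "s37", "s38", "s39");
}
```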
```diff
@@ -84,22 +103,6 @@
     " buffer_load_dword v20, v8, s[40:43], 0 offen \n"
     " buffer_load_dword v21, v9, s[40:43], 0 offen \n"
-    "s_mov_b32 s16, %[s_res_dq0] \n"
-    "s_mov_b32 s17, %[s_res_dq1] \n"
-    "s_mov_b32 s18, %[s_res_dq2] \n"
-    "s_mov_b32 s19, %[s_res_dq3] \n"
-    "s_mov_b32 s32, %[s_res_gq0] \n"
-    "s_mov_b32 s33, %[s_res_gq1] \n"
-    "s_mov_b32 s34, %[s_res_gq2] \n"
-    "s_mov_b32 s35, %[s_res_gq3] \n"
-    "s_mov_b32 s20, %[s_res_a0] \n"
-    "s_mov_b32 s21, %[s_res_a1] \n"
-    "s_mov_b32 s22, %[s_res_a2] \n"
-    "s_mov_b32 s23, %[s_res_a3] \n"
-    "s_mov_b32 s24, %[s_res_b0] \n"
-    "s_mov_b32 s25, %[s_res_b1] \n"
-    "s_mov_b32 s26, %[s_res_b2] \n"
-    "s_mov_b32 s27, %[s_res_b3] \n"
     " s_mov_b32 s80, 0 \n"
     //---------------------v26-33 no need
     // "s_nop 4\n"
```
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp

```diff
@@ -180,6 +180,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         /////////////
         index_t a_scale_expert_stride_0   = kargs.hidden_size;
         index_t g_scale_expert_stride_0   = shared_intermediate_size_0;
+        index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
         index_t d_scale_expert_stride_1   = kargs.hidden_size; // nr*kr*w
         index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
```
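Each scale tensor is laid out per expert, so the pipeline computes pointers as base + expert_id * expert_stride + tile_id * Block_N0; the new smq_scale_expert_stride_0 feeds the same pattern. A standalone sketch of that arithmetic, with an illustrative helper name and sizes:

```cpp
// Standalone sketch of the per-expert, per-tile pointer arithmetic used for the
// gq/smq/dq scale buffers. Helper name and the numbers in main() are illustrative.
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t      = std::int32_t;
using long_index_t = std::int64_t;

const float* expert_tile_ptr(const float* base,
                             index_t expert_id,
                             index_t expert_stride, // e.g. smq_scale_expert_stride_0
                             index_t tile_id,       // e.g. intermediate_tile_id
                             index_t block_n0)      // e.g. BlockShape::Block_N0
{
    return base + static_cast<long_index_t>(expert_id) * expert_stride +
           static_cast<long_index_t>(tile_id) * block_n0;
}

int main()
{
    std::vector<float> scales(4 * 1024, 1.0f); // 4 experts, 1024 scales each
    const float* p = expert_tile_ptr(scales.data(), /*expert_id=*/2, /*stride=*/1024,
                                     /*tile_id=*/3, /*block_n0=*/32);
    std::printf("element offset = %td\n", p - scales.data()); // 2*1024 + 3*32 = 2144
    return 0;
}
```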
```diff
@@ -244,12 +245,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                               number<decltype(g_win)::NumAccess_NonLinear>{});
         //////gq
         auto gq_win = [&]() {
-            const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) +
+            const GScaleDataType* gq_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr) +
                 static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
                 intermediate_tile_id * BlockShape::Block_N0;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
-            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(g_ptr,
+            auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(gq_ptr,
                 make_tuple(shared_intermediate_size_1), number<1>{});
```
```diff
@@ -257,7 +258,22 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         }();
         auto gq_res = gq_win.get_buffer_view().cached_buf_res_;
         ////
+        ////smQ
+        auto smq_win = [&]() {
+            const YSmoothScaleDataType* smq_ptr =
+                reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) +
+                static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
+                intermediate_tile_id * BlockShape::Block_N0;
+            // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
+            auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(smq_ptr,
+                make_tuple(shared_intermediate_size_1), number<1>{});
+            return smq_view_;
+        }();
+        auto smq_res = smq_win.get_buffer_view().cached_buf_res_;
         /////////////////////
         const auto d_win = [&]() {
             const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                 static_cast<long_index_t>(expert_id) * expert_stride_1 +
```
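The gq and the new smQ buffers are wired identically: an immediately-invoked lambda offsets the base pointer to the current expert and tile, wraps it in a 1-D global view, and the pipeline then takes the view's cached_buf_res_ for the asm micro-kernel. A condensed sketch of that shared pattern (the free-function wrapper and its parameter names are illustrative; the ck_tile calls mirror the hunk above and assume this file's headers):

```cpp
// Condensed, illustrative form of the gq/smQ view construction above (not part
// of the source). Assumes the ck_tile helpers already used in this file
// (make_naive_tensor_view, make_tuple, number, index_t, long_index_t).
template <typename ScaleType, typename BlockShape_>
CK_TILE_DEVICE auto make_expert_scale_view(const ScaleType* base,
                                           index_t expert_id,
                                           index_t expert_stride,        // *_scale_expert_stride_0
                                           index_t intermediate_tile_id,
                                           index_t length)               // shared_intermediate_size_1
{
    const ScaleType* ptr = base + static_cast<long_index_t>(expert_id) * expert_stride +
                           intermediate_tile_id * BlockShape_::Block_N0;
    // 1-D global view over this expert's scales; the caller then takes
    // get_buffer_view().cached_buf_res_ to feed the asm micro-kernel.
    return make_naive_tensor_view<address_space_enum::global>(
        ptr, make_tuple(length), number<1>{});
}
```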
```diff
@@ -284,8 +300,9 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
         //////gq
         auto dq_win = [&]() {
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr) + static_cast<long_index_t>(expert_id) * d_scale_expert_stride_0;
-            const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr) //remember to add expert_id as expert_idx
+            const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
+                static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
             auto g_view_ = make_naive_tensor_view<address_space_enum::global>(g_ptr,
                 make_tuple(kargs.hidden_size),
```
```diff
@@ -368,7 +385,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             kargs.hidden_size,
             BlockShape::Block_K0, // tile offset for B matrix each unroll
-            BlockShape::Block_Kr0 * BlockShape::Block_W0); // tile offset for B matrix each unroll
+            BlockShape::Block_W0); // tile offset for B matrix each unroll
         // sweep_tile(
         //     acc_0,
```
```diff
@@ -396,6 +413,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             smem,
             kargs.hidden_size, // total n number
             w_scale,
+            BlockShape::Block_N1,
             shared_intermediate_size_1 * Block_N1 - kr_1 * BlockShape::Block_W1, // along N
             kr_1 * BlockShape::Block_W1,
             BlockShape::Block_N1); // along N
```