gaoqiong / composable_kernel_ROCM

Commit d0c80b12, authored Jan 03, 2025 by shengnxu
Parent: a759277d

    fix more issues; current status: inline asm using more registers than available
Showing 10 changed files with 143 additions and 68 deletions
  example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp                             +1  -0
  example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp                    +17 -2
  example/ck_tile/15_fused_moe/instances/fused_moegemm_int8_m32.cpp                        +14 -0
  example/ck_tile/15_fused_moe/main.cpp                                                    +17 -9
  include/ck_tile/host/reference/reference_fused_moe.hpp                                   +1  -2
  include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp               +7  -3
  include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp            +1  -11
  include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc    +11 -12
  include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp          +21 -0
  include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp         +53 -29
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp

@@ -37,3 +37,4 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
     // clang-format on
     return r;
 }
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp

@@ -5,11 +5,26 @@
 #include "fused_moegemm_api_traits.hpp"
 #include "ck_tile/ops/fused_moe.hpp"
 #include "fused_moegemm_api.cpp"
+#include <iostream>

 template <ck_tile::index_t... Is>
 using S = ck_tile::sequence<Is...>;

+template <typename dtype, typename problem>
+struct PipelineDispatch;
+template <typename problem>
+struct PipelineDispatch<ck_tile::int8_t, problem>
+{
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<problem>;
+};
+template <typename problem>
+struct PipelineDispatch<ck_tile::bf16_t, problem>
+{
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
+};
+template <typename problem>
+struct PipelineDispatch<ck_tile::fp16_t, problem>
+{
+    using type = ck_tile::FusedMoeGemmPipeline_FlatmmUk<problem>;
+};

 // do not define this template function inside _api.cpp, otherwise it will block make -j
 template <typename Ts_>
 float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
@@ -38,8 +53,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
         f_traits>;
     // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
-    using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk_int8<>;
+    using f_pipeline = typename PipelineDispatch<typename Ts_::ADataType, f_problem>::type;
     using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
     using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
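Note: PipelineDispatch is the standard type-map idiom: declare the primary template, then add one partial specialization per element type, so the caller above can pick a pipeline with a single alias lookup. A minimal standalone sketch of the same idiom (placeholder pipeline and problem types, not the real ck_tile ones):

#include <cstdint>
#include <type_traits>

template <typename Problem> struct PipelineInt8 {}; // stand-ins for the
template <typename Problem> struct PipelineFp   {}; // ck_tile pipelines

// Primary template is only declared: an unsupported dtype fails to compile.
template <typename dtype, typename Problem>
struct PipelineDispatch;

template <typename Problem>
struct PipelineDispatch<std::int8_t, Problem> { using type = PipelineInt8<Problem>; };

template <typename Problem>
struct PipelineDispatch<float, Problem> { using type = PipelineFp<Problem>; };

struct SomeProblem {};
static_assert(std::is_same_v<PipelineDispatch<std::int8_t, SomeProblem>::type,
                             PipelineInt8<SomeProblem>>);

int main() {}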
example/ck_tile/15_fused_moe/instances/fused_moegemm_int8_m32.cpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"

// clang-format off
template float fused_moegemm_<fmoe_<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
                                    float, float, float, float,
                                    S<32, 512, 256, 256>, S<1, 4, 1>, S<16, 16, 64>, 1, 1>>(
    const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
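Note: this new file is an explicit-instantiation translation unit: the template definition lives in the shared header (see the make -j comment above), and each .cpp pins exactly one parameterization so the instantiations compile in parallel. The mechanism in miniature, with a hypothetical twice() template:

#include <cassert>

template <typename T>
T twice(T v) { return v + v; }

// Explicit instantiation definition: forces twice<int> to be compiled and
// emitted in this translation unit, the way fused_moegemm_int8_m32.cpp pins
// the int8 fused_moegemm_ instantiation.
template int twice<int>(int);

int main() { assert(twice(21) == 42); }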
example/ck_tile/15_fused_moe/main.cpp

@@ -204,7 +204,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ADataType = typename TypeConfig::ADataType;
     using GDataType = typename TypeConfig::GDataType;
     using DDataType = typename TypeConfig::DDataType;
-    using AccDataType = typename TypeConfig::AccDataType;
+    // using AccDataType = typename TypeConfig::AccDataType;
     using ODataType = typename TypeConfig::ODataType;
     using AScaleDataType = typename TypeConfig::AScaleDataType;
     using GScaleDataType = typename TypeConfig::GScaleDataType;
@@ -218,12 +218,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
     ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});

-    if(fused_quant == 1)
-    {
+    // if (fused_quant == 1)
+    // {
         ck_tile::HostTensor<AScaleDataType> sa_host({tokens, topk});
-    }
-    else
-    {
-        ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
-    }
+    // } else{
+    //     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
+    // }

     ck_tile::HostTensor<GScaleDataType> sg_host({experts, shared_intermediate_size_0});
     ck_tile::HostTensor<DScaleDataType> sd_host({experts, shared_intermediate_size_1});
     ck_tile::HostTensor<YSmoothScaleDataType> sy_host({experts, shared_intermediate_size_1});
     // smooth-quant
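Note: the if/else scaffolding around sa_host appears to be commented out because a tensor declared inside those braces is destroyed at the closing brace and invisible afterwards; declaring the {tokens, topk} variant unconditionally sidesteps that. The scope pitfall in miniature:

#include <vector>

int main()
{
    bool fused_quant = true;
    if(fused_quant) { std::vector<float> sa_host(8); } // dies at '}'
    else            { std::vector<float> sa_host(4); } // dies at '}'
    // 'sa_host' is not in scope here; the fix declares it once, up front:
    std::vector<float> sa_host(8);
    return sa_host.size() == 8 ? 0 : 1;
}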
@@ -425,7 +425,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         experts,
         block_m);

-    ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
+    ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
         a_host,
         g_host,
         d_host,
@@ -535,7 +535,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
+        ck_tile::reference_fused_moe<float, ck_tile::element_wise::Gelu>(
             a_host,
            g_host,
            d_host,
@@ -555,7 +555,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
             hidden_size,
             shared_intermediate_size_0,
             topk,
-            gate_only);
+            gate_only,
+            1);

     auto o_dev = o_buf.ToHost<ODataType>();
     // o_dev.savetxt("gpu-out.txt", "float");
@@ -604,6 +605,13 @@ int main(int argc, char* argv[])
                    ? 0
                    : -2;
     }
+    else if(prec_i == "int8" && prec_w == "int8" && prec_o == "bf16" && prec_kw == "fp32")
+    {
+        return run<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, float, float, float, float>(
+                   arg_parser)
+                   ? 0
+                   : -2;
+    }

     return -3;
 }
include/ck_tile/host/reference/reference_fused_moe.hpp

@@ -107,10 +107,9 @@ void reference_fused_moe(
             return;
         ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
         ck_tile::index_t i_token  = sorted_token_ids_host.mData[i_flatten];
-        ck_tile::index_t i_weight_idx;
+        ck_tile::index_t i_weight_idx = i_token >> 24;
         if(fquant == 1)
         {
-            i_weight_idx = i_token >> 24;
             i_token = i_token & 0xffffff;
         }
         if(i_token >= tokens)
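Note: the sorted token ids pack two fields into one 32-bit word: the topk-weight index in the top 8 bits and the token id in the low 24 (hence the >> 24 and & 0xffffff above); the change unpacks the weight index unconditionally and keeps only the masking behind fquant. A self-contained sketch of that layout (helper names are illustrative, not from the repo):

#include <cstdint>

// [ 8-bit weight index | 24-bit token id ]
constexpr int32_t pack_sorted_id(int32_t token, int32_t weight_idx)
{
    return (weight_idx << 24) | (token & 0xffffff);
}
constexpr int32_t unpack_weight_idx(int32_t packed) { return packed >> 24; }
constexpr int32_t unpack_token(int32_t packed) { return packed & 0xffffff; }

int main()
{
    constexpr int32_t p = pack_sorted_id(/*token=*/123456, /*weight_idx=*/3);
    static_assert(unpack_weight_idx(p) == 3);
    static_assert(unpack_token(p) == 123456);
}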
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp

@@ -245,9 +245,9 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
     // TODO: need paired with tile_window_linear!
     // TODO: need call init_raw() before call this function!
-    template <typename AQRes,
+    template <typename AToken_id,
+              typename AQRes,
               typename DQRes,
               typename GQRes,
               typename SMQRes,
               typename ARes,
               typename ACoords,
               typename BRes,
               typename BCoords>
-    CK_TILE_DEVICE auto operator()(index_t row_ids_a_,
+    CK_TILE_DEVICE auto operator()(const AToken_id& row_ids_a_,
                                    const AQRes& res_aq,
                                    const DQRes& res_dq,
                                    const GQRes& res_gq,
@@ -263,6 +263,7 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
     {
         static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
         static_assert(BCoords::size() == Repeat_N);
+        static_assert(AToken_id::size() == Repeat_M);

         auto a_sst = make_tile_window(
             make_tensor_view<address_space_enum::lds>(
@@ -371,6 +372,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
         register int v_z62 asm("v190") = 0;
         register int v_z63 asm("v191") = 0;

+        index_t temp0 = static_cast<index_t>(row_ids_a_[number<0>{}]);
+        index_t temp1 = static_cast<index_t>(row_ids_a_[number<1>{}]);

         // B nr->kr
 #pragma clang diagnostic push
@@ -397,7 +400,8 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
             // [v_acc_13]"+v"(v_acc[13]),
             // [v_acc_14]"+v"(v_acc[14]),
             // [v_acc_15]"+v"(v_acc[15]),
-            [v_token_id] "+v"(row_ids_a_),
+            [v_token_id0] "+v"(temp0),
+            [v_token_id1] "+v"(temp1),
             [s_mem_] "+r"(smem)
             : [s_res_aq0] "s"(res_aq[0]),
               [s_res_aq1] "s"(res_aq[1]),
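Note: temp0/temp1 exist so the two packed row ids can be handed to the asm block as separate "+v" operands instead of one aggregate; together with the register int ... asm("vNNN") pins, each binding takes one VGPR away from the allocator, which is what the commit message's "inline asm using more registers than available" refers to. A stripped-down device-side illustration of the binding pattern (AMD GCN only, e.g. hipcc targeting gfx9; not host-runnable):

__device__ int fold_ids(int packed0, int packed1)
{
    // Pin two scratch values to fixed VGPRs; each pin permanently claims one.
    register int v_z62 asm("v190") = 0;
    register int v_z63 asm("v191") = 0;

    int t0 = packed0; // plain copies: the allocator chooses their VGPRs,
    int t1 = packed1; // then "+v"/"v" binds them to asm operands
    asm volatile(" v_add_u32 %[t0], %[t0], %[t1]\n"
                 : [t0] "+v"(t0)
                 : [t1] "v"(t1), "v"(v_z62), "v"(v_z63));
    return t0;
}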
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp

@@ -98,7 +98,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
                                    index_t tile_offset_half_b, // split load along K into 2 parts
                                    index_t tile_offset_o)
     {
-        static_assert(BCoords::size() == 8); // 8
+        static_assert(BCoords::size() == 4); // 8
         static_assert(OCoords::size() == 8);

         const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
@@ -238,11 +238,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
           [v_os_b1] "v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
           [v_os_b2] "v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
           [v_os_b3] "v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
-          [v_os_b4] "v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
-          [v_os_b5] "v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
-          [v_os_b6] "v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
-          [v_os_b7] "v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
           [s_tile_os_o] "s"(tile_stride_o_bytes),
-          [s_tile_os_b_half] "s"(tile_offset_half_b_bytes),
           [s_tile_os_b] "s"(tile_stride_b_bytes),
@@ -393,11 +388,6 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
           [v_os_b1] "v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
           [v_os_b2] "v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
           [v_os_b3] "v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
-          [v_os_b4] "v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
-          [v_os_b5] "v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
-          [v_os_b6] "v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
-          [v_os_b7] "v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
           [s_tile_os_o] "s"(tile_stride_o_bytes),
-          [s_tile_os_b_half] "s"(tile_offset_half_b_bytes),
           [s_tile_os_b] "s"(tile_stride_b_bytes),
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc

@@ -78,18 +78,17 @@
         " v_mov_b32 v53, 0x00007fff                     \n"
         " s_waitcnt 0x0000                              \n"
         ";----------------------------------------------\n"
-        " v_mov_b32 %[v_token_id], %[v_token_id]        \n"
-        " v_lshrrev_b32 v54, 24, %[v_token_id]          \n"
-        " v_mul_i32_i24 v54, s66, v54                   \n"
-        " v_and_b32 v55, 0x00ffffff, %[v_token_id]      \n"
-        " v_add_u32 %[v_token_id], v54, v55             \n"
-        " v_lshrrev_b32 v54, 24, v7                     \n"
-        " v_mul_i32_i24 v54, s66, v54                   \n"
-        " v_and_b32 v55, 0x00ffffff, v7                 \n"
+        " v_lshrrev_b32 v54, 24, %[v_token_id0]         \n"
+        " v_mul_i32_i24 v54, s66, v54                   \n"
+        " v_and_b32 v55, 0x00ffffff, %[v_token_id0]     \n"
+        " v_add_u32 v6, v54, v55                        \n"
+        " v_lshrrev_b32 v54, 24, %[v_token_id1]         \n"
+        " v_mul_i32_i24 v54, s66, v54                   \n"
+        " v_and_b32 v55, 0x00ffffff, %[v_token_id1]     \n"
         " v_add_u32 v7, v54, v55                        \n"
-        " v_lshlrev_b32 %[v_token_id], 2, %[v_token_id] \n"
+        " v_lshlrev_b32 v6, 2, v6                       \n"
         " v_lshlrev_b32 v7, 2, v7                       \n"
-        " buffer_load_dword v14, %[v_token_id], s[28:31], 0 offen \n"
+        " buffer_load_dword v14, v6, s[28:31], 0 offen  \n"
         " buffer_load_dword v15, v7, s[28:31], 0 offen  \n"
         " buffer_load_dword v16, v10, s[32:35], 0 offen \n"
         " buffer_load_dword v17, v11, s[32:35], 0 offen \n"
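Note: the rewritten preamble computes, for each of the two packed ids, row = (id >> 24) * s66 + (id & 0xffffff) and then shifts left by 2 to turn a dword index into a byte offset for buffer_load_dword; given the "fake token id, 2D index for X scale" comment elsewhere in this commit, s66 plausibly holds the pitch that flattens (weight_idx, token) into one index, though the .inc does not say so. Scalar model of the per-lane math:

#include <cstdint>

// Mirrors the v_lshrrev/v_mul_i32_i24/v_and/v_add/v_lshlrev chain above;
// 'scale' stands in for whatever value the kernel keeps in s66.
constexpr int32_t lane_byte_offset(int32_t packed_id, int32_t scale)
{
    int32_t hi  = packed_id >> 24;        // v_lshrrev_b32 v54, 24, id
    int32_t mul = scale * hi;             // v_mul_i32_i24 v54, s66, v54
    int32_t lo  = packed_id & 0x00ffffff; // v_and_b32     v55, 0x00ffffff, id
    return (mul + lo) << 2;               // v_add_u32 + v_lshlrev_b32 ..., 2, ...
}

int main()
{
    static_assert(lane_byte_offset((2 << 24) | 100, 7) == ((2 * 7 + 100) << 2));
}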
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp

@@ -804,6 +804,13 @@ struct FusedMoeGemmPipelineFlatmmPolicy
         {
             return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
         }
+        else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
+                          std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
+                          S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 256 &&
+                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64)
+        {
+            return Flatmm_32x512x256_1x4x1_16x16x64_int8{};
+        }
     }

     template <typename Problem>
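Note: kernel selection here is a compile-time chain: if constexpr over the data types and tile extents, with each branch returning a different functor type (legal because only the taken branch is instantiated, so the deduced return type is unambiguous per instantiation). Reduced standalone sketch with stand-in kernel types:

#include <cstdint>
#include <type_traits>

struct KernelFp16 {}; // stand-ins for Flatmm_32x512x128_..._FP16 etc.
struct KernelInt8 {};

template <typename AType, int BlockN0>
constexpr auto select_kernel()
{
    if constexpr(std::is_same_v<AType, std::int8_t> && BlockN0 == 512)
        return KernelInt8{};
    else
        return KernelFp16{};
}

static_assert(std::is_same_v<decltype(select_kernel<std::int8_t, 512>()), KernelInt8>);
static_assert(std::is_same_v<decltype(select_kernel<float, 512>()), KernelFp16>);

int main() {}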
@@ -851,6 +858,20 @@ struct FusedMoeGemmPipelineFlatmmPolicy
             // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
             return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
         }
+        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
+                          std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
+                          S_::Block_M1 == 32 && S_::Block_N1 == 256 && S_::Block_K1 == 512 &&
+                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 64 &&
+                          T_::PipeInterleave == false)
+        {
+            return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
+            // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
+        }
+        else
+        {
+            return FlatmmSn_32x256x512_1x4x1_16x16x64_int8{};
+            // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
+        }
     }
 };
 } // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp

@@ -116,7 +116,19 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         return coords;
     }

+    CK_TILE_DEVICE auto GetRowCoords_A_mma(index_t base_offset)
+    {
+        // constexpr index_t KLans = 2;
+        constexpr index_t MLans   = 16;
+        constexpr index_t MRepeat = BlockShape::Repeat_M1;
+
+        auto base_coord = threadIdx.x % MLans + base_offset;
+
+        array<index_t, MRepeat> coords;
+        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
+        return coords;
+    }
+
     template <typename ROW_COORDS>
     CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
     {
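Note: GetRowCoords_A_mma gives each lane the list of A rows it owns in the 16x16 MFMA layout: lane m = threadIdx.x % 16 starts at base_offset + m and advances by 16 per M-repeat. Host-side model of the index math (MRepeat fixed to 2 here; in the kernel it is BlockShape::Repeat_M1):

#include <array>
#include <cassert>
#include <cstdint>

std::array<int32_t, 2> row_coords_a_mma(int32_t thread_idx, int32_t base_offset)
{
    constexpr int32_t MLans   = 16; // lanes along M in one 16x16 MFMA
    constexpr int32_t MRepeat = 2;  // stand-in for BlockShape::Repeat_M1
    std::array<int32_t, MRepeat> coords{};
    const int32_t base = thread_idx % MLans + base_offset;
    for(int32_t i = 0; i < MRepeat; ++i)
        coords[i] = base + i * MLans; // this lane's row in repeat i
    return coords;
}

int main()
{
    auto c = row_coords_a_mma(/*thread_idx=*/5, /*base_offset=*/32);
    assert(c[0] == 37 && c[1] == 53);
}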
@@ -178,7 +190,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
         index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
         /////////////
         index_t a_scale_expert_stride_0   = kargs.hidden_size;
         index_t g_scale_expert_stride_0   = shared_intermediate_size_0;
         index_t smq_scale_expert_stride_0 = shared_intermediate_size_0;
         index_t d_scale_expert_stride_1   = kargs.hidden_size;
@@ -192,13 +203,25 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                                          BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)

         auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
+        auto row_coords_a_mma = GetRowCoords_A_mma(sorted_tile_id * BlockShape::Block_M0);
         auto row_ids_a = GetRowID(row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
-        auto token_id = row_ids_a & 0xffffff;
+        auto row_ids_a_mma = GetRowID(row_coords_a_mma, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
+        auto token_id = generate_tuple(
+            [&](auto i) {
+                return (row_ids_a[i]) & 0xffffff;
+            },
+            number<row_ids_a.size()>{});
+        // auto token_id_mma = generate_tuple(
+        //     [&](auto i) {
+        //         return (row_ids_a_mma[i]) & 0xffffff;
+        //     },
+        //     number<row_ids_a_mma.size()>{});
         // addr in fact
         auto a_coords = generate_tuple(
             [&](auto i) {
-                return (token_id[i]) * kargs.stride_token +
+                return (row_ids_a[i]) * kargs.stride_token +
                        threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
             },
             number<row_ids_a.size()>{});
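Note: token_id changes from a single masked value to an element-wise masked copy of row_ids_a built with generate_tuple, which is what lets the later o_flags code index it per row. The same transformation modeled with std::array in place of ck_tile's tuple:

#include <array>
#include <cassert>
#include <cstdint>

template <std::size_t N>
std::array<int32_t, N> mask_token_ids(const std::array<int32_t, N>& row_ids)
{
    std::array<int32_t, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i] = row_ids[i] & 0xffffff; // strip the weight index in bits 24..31
    return out;
}

int main()
{
    std::array<int32_t, 2> ids{(3 << 24) | 7, (1 << 24) | 9};
    auto t = mask_token_ids(ids);
    assert(t[0] == 7 && t[1] == 9);
}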
@@ -208,7 +231,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         ////// aq
         auto aq_win = [&]() {
             const AScaleDataType* aq_ptr = reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr);
-            auto aq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto aq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 aq_ptr, make_tuple(kargs.num_tokens * kargs.topk), number<1>{});
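Note: this and the three hunks below switch the 1-D scale views to the _packed variant, which, on my reading of the API, derives strides from the lengths alone (innermost stride 1) instead of taking them explicitly. The idea in isolation:

#include <array>
#include <cstddef>

// A packed (dense, row-major) view computes strides from its extents.
template <std::size_t N>
constexpr std::array<std::size_t, N> packed_strides(std::array<std::size_t, N> lengths)
{
    std::array<std::size_t, N> strides{};
    std::size_t s = 1;
    for(std::size_t i = N; i-- > 0;)
    {
        strides[i] = s;
        s *= lengths[i];
    }
    return strides;
}

int main()
{
    constexpr auto st = packed_strides<3>({4, 5, 6});
    static_assert(st[0] == 30 && st[1] == 6 && st[2] == 1);
}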
@@ -249,7 +272,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                 static_cast<long_index_t>(expert_id) * g_scale_expert_stride_0 +
                 intermediate_tile_id * BlockShape::Block_N0;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr); // remember to add expert id for inline
-            auto gq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto gq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 gq_ptr, make_tuple(shared_intermediate_size_1), number<1>{});
@@ -264,7 +287,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
                 static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
                 intermediate_tile_id * BlockShape::Block_N0;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr); // remember to add expert id for inline
-            auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto smq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 smq_ptr, make_tuple(shared_intermediate_size_1), number<1>{});
@@ -303,7 +326,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
                                           static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
             // const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr) // remember to add expert_id as expert_idx
-            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
+            auto g_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
                 g_ptr, make_tuple(kargs.hidden_size), number<1>{});
@@ -323,12 +346,12 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
         auto d_coords = [&]() {
             constexpr index_t Nr_ = 4;
             constexpr index_t Nw_ = 4;
-            constexpr index_t Kr0_ = 2; // no more need in int8, method changed, this will be handled in res_s
-            constexpr index_t Kr1_ = 4;
-            constexpr index_t Kl_ = 4;
+            // constexpr index_t Kr0_ = 2; // no more need in int8, method changed, this will be handled in res_s
+            // constexpr index_t Kr1_ = 4;
+            // constexpr index_t Kl_ = 4;
             constexpr index_t Nl_ = 16;
             constexpr index_t Kv_ = 16;
-            constexpr index_t W_ = Kl_ * Nl_ * Kv_;
+            // constexpr index_t W_ = Kl_ * Nl_ * Kv_;
             // constexpr index_t num_offsets_ = Nr_ * Kr0_;
             constexpr index_t num_offsets_ = Nr_;
             index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
@@ -351,18 +374,18 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             number<row_ids_a.size()>{});

         auto o_flags = generate_tuple(
-            [&](auto i) { return cmp_lt_to_exec(token_id, kargs.num_tokens); },
+            [&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
             number<row_ids_a.size()>{});

-        auto bridge_sst_win = [&]() {
-            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
-            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
-            return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
-                                               reinterpret_cast<YDataType*>(smem), desc_),
-                                           desc_.get_lengths(),
-                                           {0, 0},
-                                           dist_);
-        }();
+        // auto bridge_sst_win = [&]() {
+        //     constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
+        //     constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
+        //     return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
+        //                                        reinterpret_cast<YDataType*>(smem), desc_),
+        //                                    desc_.get_lengths(),
+        //                                    {0, 0},
+        //                                    dist_);
+        // }();

         auto o_res = make_wave_buffer_resource(
             reinterpret_cast<const ODataType*>(kargs.o_ptr),
             kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
@@ -372,12 +395,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

         auto uk_0 = Policy::template GetUK_0<Problem>();
-        auto acc_0 = uk_0(row_ids_a,    // fake token id, 2D index for X scale
+        // auto acc_0 = uk_0(
+        uk_0(row_ids_a_mma,             // fake token id, 2D index for X scale
             aq_res,
             gq_res,
             gq_res,
             dq_res,
             gq_res,
             smq_res,
             a_res,
             a_coords,
             g_res,
@@ -415,7 +439,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_N1,
-            shared_intermediate_size_1 * Block_N1 - kr_1 * BlockShape::Block_W1,             // along N
+            shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N
             kr_1 * BlockShape::Block_W1,
             BlockShape::Block_N1); // along N
     }