Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cf646183
Commit
cf646183
authored
Nov 06, 2024
by
carlushuang
Browse files
compile OK
parent
70fa98ad
Changes
20
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
504 additions
and
430 deletions
+504
-430
example/ck_tile/15_fused_moe/fused_moegemm.hpp
example/ck_tile/15_fused_moe/fused_moegemm.hpp
+4
-4
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+8
-16
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+28
-20
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+24
-20
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+14
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+34
-61
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+1
-19
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+18
-0
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+26
-0
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+1
-1
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+58
-0
include/ck_tile/core/tensor/update_tile.hpp
include/ck_tile/core/tensor/update_tile.hpp
+1
-0
include/ck_tile/host/device_memory.hpp
include/ck_tile/host/device_memory.hpp
+9
-4
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+54
-51
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
+11
-10
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
...e/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
+4
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
.../ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
+103
-81
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+100
-134
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+5
-5
No files found.
example/ck_tile/15_fused_moe/fused_moegemm.hpp
View file @
cf646183
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/
layernorm2d
.hpp"
#include "ck_tile/ops/
fused_moe
.hpp"
#include <string>
#include <string>
// this is only a convenient structure for creating an example
// this is only a convenient structure for creating an example
...
@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
...
@@ -14,7 +14,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
struct
FusedMoeGemmTypeConfig
;
struct
FusedMoeGemmTypeConfig
;
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
{
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
ADataType
=
ck_tile
::
bf16_t
;
using
GDataType
=
ck_tile
::
bf16_t
;
using
GDataType
=
ck_tile
::
bf16_t
;
...
@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
...
@@ -30,7 +30,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
};
};
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
{
{
using
ADataType
=
ck_tile
::
int8_t
;
using
ADataType
=
ck_tile
::
int8_t
;
using
GDataType
=
ck_tile
::
int8_t
;
using
GDataType
=
ck_tile
::
int8_t
;
...
@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
...
@@ -46,7 +46,7 @@ struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
};
};
// runtime args
// runtime args
struct
fused_moegemm_args
:
public
ck_tile
::
Layernorm2dFwd
HostArgs
struct
fused_moegemm_args
:
public
ck_tile
::
FusedMoeGemm
HostArgs
{
{
};
};
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
cf646183
...
@@ -3,33 +3,25 @@
...
@@ -3,33 +3,25 @@
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API only declare, not define here, otherwise will block `make -j`
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
<
ck_tile
::
index_t
...
Is
>
using
S
=
ck_tile
::
sequence
<
Is
...
>
;
float
fused_moegemm
(
fused_moegemm_traits
t
,
fused_moegemm_args
a
,
const
ck_tile
::
stream_config
&
s
)
float
fused_moegemm
(
fused_moegemm_traits
t
,
fused_moegemm_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
{
template
<
ck_tile
::
index_t
...
Is
>
// clang-format off
using
S
=
ck_tile
::
sequence
<
Is
...
>
;
float
r
=
-
1
;
float
r
=
-
1
;
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
block_m
==
32
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
;
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
4
,
1
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
;
fused_moegemm_
<
t_
>
(
s
,
a
);
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
return
r
;
return
r
;
}
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
View file @
cf646183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template
<
ck_tile
::
index_t
...
Is
>
using
S
=
ck_tile
::
sequence
<
Is
...
>
;
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
template
<
typename
Ts_
>
template
<
typename
Ts_
>
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
)
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
)
{
{
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts
::
WarpTile_0
,
typename
Ts
_
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts
::
WarpTile_0
>
;
typename
Ts_
::
WarpTile_0
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
using
f_problem
=
typename
Ts_
::
GDataType
,
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
IndexDataType
,
typename
Ts_
::
TopkWeightDataType
,
ck_tile
::
Gelu
,
// TODO: hardcoded
typename
Ts_
::
IndexDataType
,
f_shape
,
ck_tile
::
element_wise
::
Gelu
,
// TODO: hardcoded
f_traits
>
f_shape
,
f_traits
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_Flatmm
<
f_problem
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_Flatmm
<
f_problem
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
View file @
cf646183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
...
@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -20,30 +22,32 @@ struct fmoe_ // traits, ugly name, only used for internal
{
{
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
remove_cvref_t
<
typename
TypeConfig
::
ADataType
>
;
using
ADataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
ADataType
>
;
using
GDataType
=
remove_cvref_t
<
typename
TypeConfig
::
GDataType
>
;
using
GDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
GDataType
>
;
using
DDataType
=
remove_cvref_t
<
typename
TypeConfig
::
DDataType
>
;
using
DDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
DDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
TypeConfig
::
AccDataType
>
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
TypeConfig
::
ODataType
>
;
using
ODataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
ODataType
>
;
using
AScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
AScaleDataType
>
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
AScaleDataType
>
;
using
GScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
GScaleDataType
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
GScaleDataType
>
;
using
DScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
DScaleDataType
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
DScaleDataType
>
;
using
YSmoothScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
TopkWeightDataType
=
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
static
constexpr
index_t
BT_
=
BlockTIle_
::
at
(
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
index_t
BI_
=
BlockTIle_
::
at
(
number
<
1
>
{});
// block intermediate
static
constexpr
ck_tile
::
index_t
BI_
=
static
constexpr
index_t
BH_
=
BlockTIle_
::
at
(
number
<
2
>
{});
// block hidden
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
static
constexpr
index_t
BD_
=
BlockTIle_
::
at
(
number
<
3
>
{});
// block down
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
WarpPerBlock_0
=
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
WarpPerBlock_1
=
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_1
=
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
0 → 100644
View file @
cf646183
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
example/ck_tile/15_fused_moe/main.cpp
View file @
cf646183
...
@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>()
...
@@ -28,7 +28,7 @@ auto get_elimit<ck_tile::bf16_t>()
template
<
typename
T
>
template
<
typename
T
>
auto
shuffle_moe_weight
(
const
ck_tile
::
HostTensor
<
T
>&
t
,
std
::
string
mfma_dtype
,
int
mfma_type
=
0
)
auto
shuffle_moe_weight
(
const
ck_tile
::
HostTensor
<
T
>&
t
,
std
::
string
mfma_dtype
,
int
mfma_type
=
0
)
{
{
static_
assert
(
t
.
get_lengths
().
size
()
==
3
);
assert
(
t
.
get_lengths
().
size
()
==
3
);
int
b_
=
t
.
get_lengths
()[
0
];
int
b_
=
t
.
get_lengths
()[
0
];
int
n_
=
t
.
get_lengths
()[
1
];
int
n_
=
t
.
get_lengths
()[
1
];
int
k_
=
t
.
get_lengths
()[
2
];
int
k_
=
t
.
get_lengths
()[
2
];
...
@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -152,11 +152,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
shared_intermediate_size
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
ck_tile
::
index_t
shared_intermediate_size
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
GDataType
=
typename
TypeConfig
::
GDataType
;
using
GDataType
=
typename
TypeConfig
::
GDataType
;
using
DDataType
=
typename
TypeConfig
::
DDataType
;
using
DDataType
=
typename
TypeConfig
::
DDataType
;
using
AccDataType
=
typename
TypeConfig
::
AccDataType
;
//
using AccDataType = typename TypeConfig::AccDataType;
using
ODataType
=
typename
TypeConfig
::
ODataType
;
using
ODataType
=
typename
TypeConfig
::
ODataType
;
using
AScaleDataType
=
typename
TypeConfig
::
AScaleDataType
;
using
AScaleDataType
=
typename
TypeConfig
::
AScaleDataType
;
using
GScaleDataType
=
typename
TypeConfig
::
GScaleDataType
;
using
GScaleDataType
=
typename
TypeConfig
::
GScaleDataType
;
...
@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -167,8 +167,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify
// host verify
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
e
,
shared_intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
e
xperts
,
shared_intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
e
,
intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
e
xperts
,
intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
shared_intermediate_size
});
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
shared_intermediate_size
});
...
@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -200,7 +200,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// do moe sorting
// do moe sorting
if
(
balance
)
if
(
balance
)
{
{
int
e_cnt
=
0
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
int
e_cnt
=
0
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
{
topk_ids_host
.
mData
[
i
]
=
e_cnt
;
topk_ids_host
.
mData
[
i
]
=
e_cnt
;
e_cnt
++
;
e_cnt
++
;
...
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -210,7 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
else
{
{
topid_unique_gen
<
IndexType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
topid_unique_gen
<
Index
Data
Type
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
}
}
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
...
@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -245,7 +246,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
base_str
+=
"="
+
prec_o
;
base_str
+=
"="
+
prec_o
;
if
(
fused_quant
!=
0
)
if
(
fused_quant
!=
0
)
{
{
base_str
+=
std
::
string
(
"("
)
+
prec_s
a
+
"|"
+
prec_s
g
+
"|"
+
prec_sq
+
")"
;
base_str
+=
std
::
string
(
"("
)
+
prec_s
t
+
"|"
+
prec_s
w
+
"|"
+
prec_sq
+
")"
;
}
}
return
base_str
;
return
base_str
;
}();
}();
...
@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -268,14 +269,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
g_buf
.
GetDeviceBuffer
(),
g_perm_buf
.
GetDeviceBuffer
(),
d_buf
.
GetDeviceBuffer
(),
d_perm_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
?
sg_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
?
sd_buf
.
GetDeviceBuffer
(),
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
...
@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -283,9 +281,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
hidden_size
,
intermediate_size
,
intermediate_size
,
num_
tokens
,
tokens
,
experts
,
experts
,
stride
};
topk
,
stride
};
float
ave_time
=
fused_moegemm
(
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
@@ -473,50 +472,24 @@ int main(int argc, char* argv[])
...
@@ -473,50 +472,24 @@ int main(int argc, char* argv[])
return
-
1
;
return
-
1
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_sx
=
arg_parser
.
get_str
(
"prec_sx"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
if
(
prec_o
==
"auto"
)
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
{
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_o
=
prec_i
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
}
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
if
(
prec_sx
==
"auto"
)
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
{
prec_sx
=
"fp32"
;
}
if
(
prec_sy
==
"auto"
)
{
prec_sy
=
"fp32"
;
}
int
save_mv
=
arg_parser
.
get_int
(
"save_mv"
);
// no dynamic quant case
// no dynamic quant case
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
if
(
prec_i
==
"bf16"
&&
prec_w
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_kw
==
"fp32"
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
// dynamic quant case, only in inference
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
)
{
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/CMakeLists.txt
View file @
cf646183
...
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
...
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
add_subdirectory
(
10_rmsnorm2d
)
add_subdirectory
(
10_rmsnorm2d
)
add_subdirectory
(
11_add_rmsnorm2d_rdquant
)
add_subdirectory
(
11_add_rmsnorm2d_rdquant
)
add_subdirectory
(
12_smoothquant
)
add_subdirectory
(
12_smoothquant
)
add_subdirectory
(
15_fused_moe
)
include/ck_tile/core/tensor/buffer_view.hpp
View file @
cf646183
...
@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -635,7 +635,7 @@ struct buffer_view<address_space_enum::global,
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
atomic_add_raw
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
atomic_add_raw
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
using
scalar_t
=
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
;
//
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global,
...
@@ -647,24 +647,6 @@ struct buffer_view<address_space_enum::global,
static_assert
(
get_address_space
()
==
address_space_enum
::
global
,
"only support global mem"
);
static_assert
(
get_address_space
()
==
address_space_enum
::
global
,
"only support global mem"
);
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool
constexpr
use_amd_buffer_addressing
=
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
int32_t
>
||
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
float
>
||
(
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
)
||
(
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
bf16_t
>
&&
scalar_per_x_vector
%
2
==
0
);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool
constexpr
use_amd_buffer_addressing
=
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
int32_t
>
;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool
constexpr
use_amd_buffer_addressing
=
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
float
>
||
(
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
half_t
>
&&
scalar_per_x_vector
%
2
==
0
)
||
(
std
::
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
bf16_t
>
&&
scalar_per_x_vector
%
2
==
0
);
#else
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add_raw
<
remove_cvref_t
<
T
>
,
amd_buffer_atomic_add_raw
<
remove_cvref_t
<
T
>
,
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @
cf646183
...
@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
...
@@ -68,6 +68,24 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
return
tile_window
.
load
(
dst_tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
load
(
dst_tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
DistributedTensor_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load_tile
(
DistributedTensor_
&
dst_tile
,
const
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
{
return
tile_window
.
load
(
dst_tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
/**
/**
* @brief Loads a tile of data using inline assembly.
* @brief Loads a tile of data using inline assembly.
*
*
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
cf646183
...
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
...
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return
unpacks
;
return
unpacks
;
}
}
namespace
detail
{
// check if 2 static_distributed_tensor has same data type and size of element
// but only difference in distribution
template
<
typename
X
,
typename
Y
>
struct
is_similiar_distributed_tensor
{
static
constexpr
bool
value
=
false
;
};
template
<
typename
TypeX
,
typename
DistX
,
typename
TypeY
,
typename
DistY
>
struct
is_similiar_distributed_tensor
<
static_distributed_tensor
<
TypeX
,
DistX
>
,
static_distributed_tensor
<
TypeY
,
DistY
>>
{
using
Tx
=
static_distributed_tensor
<
TypeX
,
DistX
>
;
using
Ty
=
static_distributed_tensor
<
TypeY
,
DistY
>
;
static
constexpr
bool
value
=
std
::
is_same_v
<
typename
Tx
::
DataType
,
typename
Ty
::
DataType
>
&&
Tx
::
get_thread_buffer_size
()
==
Ty
::
get_thread_buffer_size
();
};
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_similiar_distributed_tensor_v
=
is_similiar_distributed_tensor
<
X
,
Y
>::
value
;
}
// namespace detail
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/tile_window.hpp
View file @
cf646183
...
@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution
...
@@ -834,7 +834,7 @@ struct tile_window_with_static_distribution
0
,
0
,
vec_value
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
);
bool_constant
<
pre_nop
>
{}
);
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
cf646183
...
@@ -509,6 +509,64 @@ struct tile_window_linear
...
@@ -509,6 +509,64 @@ struct tile_window_linear
return
dst_tensor
;
return
dst_tensor
;
}
}
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
// data index [y0, y1, ...]
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_index
(
IAccess
);
// write into distributed tensor
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_diff_ys
[
jj
]
+
j
)
:
idx_diff_ys
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
};
WINDOW_DISPATCH_ISSUE
();
return
dst_tensor
;
}
template
<
typename
DstTile
,
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
...
...
include/ck_tile/core/tensor/update_tile.hpp
View file @
cf646183
...
@@ -84,6 +84,7 @@ template <typename BottomTensorView_,
...
@@ -84,6 +84,7 @@ template <typename BottomTensorView_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
typename
DataType_
,
index_t
i_access
=
-
1
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
...
...
include/ck_tile/host/device_memory.hpp
View file @
cf646183
...
@@ -37,7 +37,7 @@ struct DeviceMem
...
@@ -37,7 +37,7 @@ struct DeviceMem
mpDeviceBuf
=
nullptr
;
mpDeviceBuf
=
nullptr
;
}
}
}
}
template
<
T
>
template
<
typename
T
>
DeviceMem
(
const
HostTensor
<
T
>&
t
)
:
mMemSize
(
t
.
get_element_space_size_in_bytes
())
DeviceMem
(
const
HostTensor
<
T
>&
t
)
:
mMemSize
(
t
.
get_element_space_size_in_bytes
())
{
{
if
(
mMemSize
!=
0
)
if
(
mMemSize
!=
0
)
...
@@ -109,18 +109,23 @@ struct DeviceMem
...
@@ -109,18 +109,23 @@ struct DeviceMem
// construct a host tensor with type T
// construct a host tensor with type T
template
<
typename
T
>
template
<
typename
T
>
HostTensor
<
T
>
ToHost
(
std
::
size_t
cpySize
=
mMemSize
)
HostTensor
<
T
>
ToHost
(
std
::
size_t
cpySize
)
{
{
// TODO: host tensor could be slightly larger than the device tensor
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
// we just copy all data from GPU buffer
std
::
size_t
host_elements
=
std
::
size_t
host_elements
=
(
cpySize
+
sizeof
(
T
)
-
1
)
/
sizeof
(
T
);
(
cpySize
+
sizeof
(
T
)
-
1
)
/
sizeof
(
T
)
HostTensor
<
T
>
h_
({
host_elements
});
HostTensor
<
T
>
h_
({
host_elements
});
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
{
{
HIP_CHECK_ERROR
(
hipMemcpy
(
h_
.
data
(),
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
HIP_CHECK_ERROR
(
hipMemcpy
(
h_
.
data
(),
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
}
}
return
h_
;
return
h_
;
}
}
template
<
typename
T
>
HostTensor
<
T
>
ToHost
()
{
return
ToHost
<
T
>
(
mMemSize
);
}
void
SetZero
()
const
void
SetZero
()
const
{
{
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
cf646183
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// [indexing implementation-1]
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top
_
k=3, M_a = 4, input_tokens = 5
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
//
// max_num_tokens_padded : top
_
k * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
// * this could be larger than actual, since actual tokens are on GPU
//
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs
...
@@ -102,7 +102,7 @@ struct FusedMoeGemmHostArgs
index_t
intermediate_size
;
// n (TP slice this)
index_t
intermediate_size
;
// n (TP slice this)
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
num_experts
;
// number of groups
//
index_t top
_
k; // need this?
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
};
...
@@ -111,14 +111,14 @@ struct FusedMoeGemmHostArgs
...
@@ -111,14 +111,14 @@ struct FusedMoeGemmHostArgs
template
<
typename
Partitioner_
,
typename
Pipeline_
,
typename
Epilogue_
>
template
<
typename
Partitioner_
,
typename
Pipeline_
,
typename
Epilogue_
>
struct
FusedMoeGemmKernel
struct
FusedMoeGemmKernel
{
{
using
Partitioner
=
remove_cvref_t
<
Partitioner_
>
;
using
Partitioner
=
remove_cvref_t
<
Partitioner_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
// TODO: not used
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
// TODO: not used
static
constexpr
index_t
kBlockSize
=
Pipeline
::
kBlockSize
;
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
// static_assert(kBlockPerCu > 0);
using
BlockShape
=
typename
Pipeline
::
BlockShape
;
// this is FusedMoeGemmShape
using
BlockShape
=
typename
Pipeline
::
BlockShape
;
// this is FusedMoeGemmShape
static
constexpr
index_t
BlockSize_
=
BlockShape
::
BlockSize
;
using
ADataType
=
typename
Pipeline
::
Problem
::
ADataType
;
using
ADataType
=
typename
Pipeline
::
Problem
::
ADataType
;
using
GDataType
=
typename
Pipeline
::
Problem
::
GDataType
;
using
GDataType
=
typename
Pipeline
::
Problem
::
GDataType
;
...
@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel
...
@@ -154,7 +154,7 @@ struct FusedMoeGemmKernel
{
{
// sync with generate.py
// sync with generate.py
// clang-format off
// clang-format off
return
""
;
// clang-format on
// clang-format on
}
}
...
@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
...
@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
index_t
intermediate_size
;
// n (TP slice this)
index_t
intermediate_size
;
// n (TP slice this)
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
num_experts
;
// number of groups
//
index_t top
_
k; // need this?
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
};
...
@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel
...
@@ -193,16 +193,20 @@ struct FusedMoeGemmKernel
return
bit_cast
<
Kargs
>
(
hargs
);
return
bit_cast
<
Kargs
>
(
hargs
);
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
num_cu
,
index_t
blocks_per_cu
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
{
return
TilePartitioner
::
GridSize
(
num_cu
,
blocks_per_cu
);
constexpr
index_t
block_m
=
BlockShape
::
Block_M0
;
int
max_num_tokens_padded
=
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
(
block_m
-
1
);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
k
BlockSize
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize
_
);
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
max
(
Pipeline
::
GetSmemSize
(),
Epilogue
::
GetSmemSize
());
// return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
return
Pipeline
::
GetSmemSize
();
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel
...
@@ -213,10 +217,10 @@ struct FusedMoeGemmKernel
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
index_t
nr_0
=
kargs
.
intermediate_size
/
Pipelin
e
::
Block_Nr0
;
index_t
nr_0
=
kargs
.
intermediate_size
/
BlockShap
e
::
Block_Nr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
Pipelin
e
::
Block_Kr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShap
e
::
Block_Kr0
;
index_t
nr_1
=
kargs
.
hidden_size
/
Pipelin
e
::
Block_Nr1
;
// should be same as kr_0
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShap
e
::
Block_Nr1
;
// should be same as kr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
Pipelin
e
::
Block_Kr1
;
// should be same as nr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
BlockShap
e
::
Block_Kr1
;
// should be same as nr_0
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
...
@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel
...
@@ -224,8 +228,8 @@ struct FusedMoeGemmKernel
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index
// note this is in unit of tile, need multiple tile size to get the index
const
auto
[
sorted_tile_id
,
hidden
_tile_id
]
=
const
auto
[
sorted_tile_id
,
intermediate
_tile_id
]
=
Tile
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
if
(
sorted_tile_id
>=
num_sorted_tiles
)
return
;
return
;
...
@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel
...
@@ -233,9 +237,10 @@ struct FusedMoeGemmKernel
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
// index along intermediate_size
// index along intermediate_size
index_t
hidden_idx
=
__builtin_amdgcn_readfirstlane
(
hidden_tile_id
*
BlockShape
::
Block_N0
);
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
index_t
hidden_idx_nr
=
// BlockShape::Block_N0);
__builtin_amdgcn_readfirstlane
(
hidden_tile_id
*
BlockShape
::
Block_Nr0
);
index_t
interm_idx_nr
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
BlockShape
::
Block_M0
;
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
BlockShape
::
Block_M0
;
...
@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel
...
@@ -265,7 +270,7 @@ struct FusedMoeGemmKernel
const
auto
a_window_
=
make_tile_window
(
const
auto
a_window_
=
make_tile_window
(
a_gather_view_
,
a_gather_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
Pipelin
e
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShap
e
::
Block_K0
>
{}),
{
0
,
0
});
{
0
,
0
});
return
a_window_
;
return
a_window_
;
}();
}();
...
@@ -274,61 +279,59 @@ struct FusedMoeGemmKernel
...
@@ -274,61 +279,59 @@ struct FusedMoeGemmKernel
const
auto
g_window
=
[
&
]()
{
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
hidden
_idx_nr
*
kr_0
*
BlockShape
::
Block_W0
;
interm
_idx_nr
*
kr_0
*
BlockShape
::
Block_W0
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
Pipelin
e
::
Block_W0
>
{}),
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShap
e
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
Pipelin
e
::
Block_W0
>
{},
1
),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShap
e
::
Block_W0
>
{},
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
g_view_1_
=
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
Pipelin
e
::
Block_Nr0
>
{},
make_tuple
(
number
<
BlockShap
e
::
Block_Nr0
>
{},
number
<
Pipelin
e
::
Block_Kr0
>
{},
number
<
BlockShap
e
::
Block_Kr0
>
{},
number
<
Pipelin
e
::
Block_W0
>
{}),
number
<
BlockShap
e
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
Pipelin
e
::
Block_Kr0
>
{},
number
<
BlockShap
e
::
Block_Kr0
>
{},
number
<
Pipelin
e
::
Block_W0
>
{}),
number
<
BlockShap
e
::
Block_W0
>
{}),
{
0
,
0
,
0
});
{
0
,
0
,
0
});
return
g_window_
;
return
g_window_
;
}();
}();
const
auto
d_window
=
[
&
]()
{
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_nr
*
BlockShape
::
Block_W1
;
hidden_idx_nr
*
BlockShape
::
Block_W1
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
// note hidden_idx_nr is along the gemm-k dim of 2nd gemm
}();
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
Pipelin
e
::
Block_W1
),
make_tuple
(
nr_1
,
kr_1
,
BlockShap
e
::
Block_W1
),
make_tuple
(
kr_1
*
Pipelin
e
::
Block_W1
,
Pipelin
e
::
Block_W1
,
1
),
make_tuple
(
kr_1
*
BlockShap
e
::
Block_W1
,
BlockShap
e
::
Block_W1
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
d_view_1_
=
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
Pipelin
e
::
k
BlockNr
_
1
>
{},
make_tuple
(
number
<
BlockShap
e
::
Block
_
Nr1
>
{},
number
<
Pipelin
e
::
k
BlockKr
_
1
>
{},
number
<
BlockShap
e
::
Block
_
Kr1
>
{},
number
<
Pipelin
e
::
Block_W1
>
{}),
number
<
BlockShap
e
::
Block_W1
>
{}),
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
Pipelin
e
::
k
BlockNr
_
1
>
{},
make_tuple
(
number
<
BlockShap
e
::
Block
_
Nr1
>
{},
number
<
Pipelin
e
::
k
BlockKr
_
1
>
{},
number
<
BlockShap
e
::
Block
_
Kr1
>
{},
number
<
Pipelin
e
::
Block_W1
>
{}),
number
<
BlockShap
e
::
Block_W1
>
{}),
{
0
,
0
,
0
});
{
0
,
0
,
0
});
return
d_window_
;
return
d_window_
;
}();
}();
auto
o_window
=
[
&
]()
{
auto
o_window
=
[
&
]()
{
const
ODataType
*
o_ptr
=
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
);
ODataType
*
o_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_ptr
);
const
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
memory_operation_enum
::
atomic_add
>
(
o_ptr
,
o_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
make_tuple
(
kargs
.
stride_token
,
1
),
...
@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel
...
@@ -336,16 +339,16 @@ struct FusedMoeGemmKernel
number
<
1
>
{});
number
<
1
>
{});
// gather is here
// gather is here
const
auto
o_scatter_view_
=
transform_tensor_view
(
auto
o_scatter_view_
=
transform_tensor_view
(
o_view_
,
o_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
o_window_
=
make_tile_window
(
auto
o_window_
=
make_tile_window
(
o_scatter_view_
,
o_scatter_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
Pipelin
e
::
Block_N1
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShap
e
::
Block_N1
>
{}),
{
0
,
0
});
{
0
,
0
});
return
o_window_
;
return
o_window_
;
}();
}();
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
View file @
cf646183
...
@@ -58,14 +58,15 @@ struct FusedMoeGemmShape
...
@@ -58,14 +58,15 @@ struct FusedMoeGemmShape
static
constexpr
index_t
NumWarps
=
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
WarpPerBlock_0
{},
multiplies
{},
number
<
1
>
{});
reduce_on_sequence
(
WarpPerBlock_0
{},
multiplies
{},
number
<
1
>
{});
// TODO: we don't support half warps aound to 1 warp here
static_assert
(
NumWarps
==
reduce_on_sequence
(
WarpPerBlock_1
{},
multiplies
{},
number
<
1
>
{}));
static_assert
(
NumWarps
==
reduce_on_sequence
(
WarpPerBlock_1
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
num
n
er
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
num
b
er
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
num
n
er
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
num
b
er
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_K0
=
WarpPerBlock_0
::
at
(
num
n
er
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_K0
=
WarpPerBlock_0
::
at
(
num
b
er
<
2
>
{});
static
constexpr
index_t
Warp_M0
=
WarpTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_M0
=
WarpTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N0
=
WarpTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_N0
=
WarpTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
...
@@ -83,12 +84,12 @@ struct FusedMoeGemmShape
...
@@ -83,12 +84,12 @@ struct FusedMoeGemmShape
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_M1
=
Warp
Tile
_1
::
at
(
num
n
er
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_M1
=
Warp
PerBlock
_1
::
at
(
num
b
er
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N1
=
Warp
Tile
_1
::
at
(
num
n
er
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_N1
=
Warp
PerBlock
_1
::
at
(
num
b
er
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_K1
=
Warp
Tile
_1
::
at
(
num
n
er
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_K1
=
Warp
PerBlock
_1
::
at
(
num
b
er
<
2
>
{});
static
constexpr
index_t
Warp_M1
=
Warp
PerBlock
_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_M1
=
Warp
Tile
_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N1
=
Warp
PerBlock
_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_N1
=
Warp
Tile
_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_K1
=
Warp
PerBlock
_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_K1
=
Warp
Tile
_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
ThreadPerBlock_M1
=
Warp_M1
*
WarpPerBlock_M1
;
static
constexpr
index_t
ThreadPerBlock_M1
=
Warp_M1
*
WarpPerBlock_M1
;
static
constexpr
index_t
ThreadPerBlock_N1
=
Warp_N1
*
WarpPerBlock_N1
;
static
constexpr
index_t
ThreadPerBlock_N1
=
Warp_N1
*
WarpPerBlock_N1
;
...
@@ -119,6 +120,6 @@ struct FusedMoeGemmShape
...
@@ -119,6 +120,6 @@ struct FusedMoeGemmShape
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
static_assert
(
Block_W0
==
Block_W1
);
static_assert
(
Block_W0
==
Block_W1
);
static_assert
(
Block_Nr0
==
Block_Kr1
);
//
static_assert(Block_Nr0 == Block_Kr1);
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
View file @
cf646183
...
@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear
...
@@ -11,10 +11,10 @@ struct FusedMoeGemmTilePartitioner_Linear
// FusedMoeGemmShape
// FusedMoeGemmShape
using
BlockShape
=
ck_tile
::
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
ck_tile
::
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
const
char
*
name
=
"
eh"
;
// expert x hidden
static
constexpr
const
char
*
name
=
"
lin"
;
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*num_sorted_tiles*/
,
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*num_sorted_tiles*/
,
ck_tile
::
index_t
/*
hidden
_size*/
)
)
ck_tile
::
index_t
/*
intermediate
_size*/
)
{
{
index_t
i_n
=
blockIdx
.
x
;
index_t
i_n
=
blockIdx
.
x
;
index_t
i_m
=
blockIdx
.
y
;
index_t
i_m
=
blockIdx
.
y
;
...
@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear
...
@@ -22,11 +22,11 @@ struct FusedMoeGemmTilePartitioner_Linear
return
ck_tile
::
make_tuple
(
i_m
,
i_n
);
return
ck_tile
::
make_tuple
(
i_m
,
i_n
);
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
max_tokens
,
index_t
hidden
_size
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
max_tokens
,
index_t
intermediate
_size
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
index_t
ms
=
ck_tile
::
integer_divide_ceil
(
max_tokens
,
BlockShape
::
Block_M0
);
index_t
ms
=
ck_tile
::
integer_divide_ceil
(
max_tokens
,
BlockShape
::
Block_M0
);
index_t
ns
=
ck_tile
::
integer_divide_ceil
(
hidden
_size
,
BlockShape
::
Block_N0
);
index_t
ns
=
ck_tile
::
integer_divide_ceil
(
intermediate
_size
,
BlockShape
::
Block_N0
);
return
dim3
(
ns
,
ms
,
1
);
return
dim3
(
ns
,
ms
,
1
);
}
}
};
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
View file @
cf646183
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
cf646183
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
cf646183
...
@@ -35,9 +35,9 @@ struct WarpGemmImpl
...
@@ -35,9 +35,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
{
{
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
CTensor
>
&&
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
C
Warp
Tensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
A
Warp
Tensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
B
Warp
Tensor
>
);
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CTensor
::
get_thread_buffer_size
()
>
;
...
@@ -85,8 +85,8 @@ struct WarpGemmImpl
...
@@ -85,8 +85,8 @@ struct WarpGemmImpl
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
{
using
CTensor
=
CWarpTensor
;
using
CTensor
=
CWarpTensor
;
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
ATensor
>
&&
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
A
Warp
Tensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BTensor
>
);
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
B
Warp
Tensor
>
);
CTensor
c
;
CTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
ATensor
::
get_thread_buffer_size
()
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment