Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
aef2b33c
Commit
aef2b33c
authored
Jan 13, 2025
by
coderfeli
Browse files
build ok
parent
075a4a43
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
147 additions
and
325 deletions
+147
-325
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+3
-3
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+3
-3
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+1
-1
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-0
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
...ile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
+27
-1
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
+72
-0
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
+10
-10
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
.../ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
+1
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
+0
-265
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+4
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+25
-37
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
aef2b33c
...
@@ -11,9 +11,9 @@ template float fused_moegemm_<
...
@@ -11,9 +11,9 @@ template float fused_moegemm_<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
//
template float fused_moegemm_<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
//
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
//
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
View file @
aef2b33c
...
@@ -10,9 +10,9 @@
...
@@ -10,9 +10,9 @@
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
//
template float fused_moegemm_<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
//
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
//
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
...
...
example/ck_tile/15_fused_moe/main.cpp
View file @
aef2b33c
...
@@ -304,7 +304,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -304,7 +304,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// permute weight
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
gate_only
?
shuffle_moe_weight
(
g_host
,
prec_w
,
1
)
:
shuffle_moe_weight_gateup
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
// do moe sorting
// do moe sorting
...
...
include/ck_tile/ops/flatmm.hpp
View file @
aef2b33c
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
View file @
aef2b33c
...
@@ -57,7 +57,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
...
@@ -57,7 +57,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
// TODO: note Nr/Kr/W need consider SubKPacks
// TODO: note Nr/Kr/W need consider SubKPacks
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
//
4
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
//
16
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 8
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 8
...
@@ -89,6 +89,32 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
...
@@ -89,6 +89,32 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return
c_block_tensor
;
return
c_block_tensor
;
}
}
// Builds the static tile distribution of the C block tile used by the
// "GUMerge" path. The block-level (outer) encoding uses Repeat_M repeats along
// M and Repeat_N / 2 repeats along N, spread over WarpPerBlock_M x
// WarpPerBlock_N warps, and is embedded with the C-warp encoding of the
// 16x16x32 transposed-C MFMA warp gemm.
static CK_TILE_DEVICE constexpr auto MakeCBlockDistGUMerge()
{
    using WarpGemm = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;

    // Outer (per-block) encoding; only half of Repeat_N is used along N.
    using OuterEncoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N / 2, WarpPerBlock_N>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<2, 1>, // !! note here is different (presumably N before M -- TODO confirm)
        sequence<0, 0>>;

    // Combine the block-level encoding with the warp-level C encoding.
    constexpr auto embedded_encoding = detail::make_embed_tile_distribution_encoding(
        OuterEncoding{}, typename WarpGemm::CWarpDstrEncoding{});

    constexpr auto block_distribution = make_static_tile_distribution(embedded_encoding);
    return block_distribution;
}
// Creates a float static distributed tensor laid out according to
// MakeCBlockDistGUMerge(). The tensor is returned as created by
// make_static_distributed_tensor; this function does not fill it.
static CK_TILE_DEVICE constexpr auto MakeCBlockTileGUMerge()
{
    using AccDataType = float;

    constexpr auto merged_distribution = MakeCBlockDistGUMerge();
    auto merged_tile = make_static_distributed_tensor<AccDataType>(merged_distribution);
    return merged_tile;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreDesc_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreDesc_A
()
{
{
// A async->LDS
// A async->LDS
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32.hpp
0 → 100644
View file @
aef2b33c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace
ck_tile
{
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
// Shape constants and helpers shared by the 32x128x256 (M x N x K) stream-N
// flatmm micro-kernels with a 1x4x1 warp layout and 16x16x32 warp tiles.
struct FlatmmSn_32x128x256_1x4x1_16x16x32_Base
{
    // Block tile extents.
    static constexpr index_t Block_M = 32;
    static constexpr index_t Block_N = 128;
    static constexpr index_t Block_K = 256;

    // Number of warps along each dimension of the block tile.
    static constexpr index_t WarpPerBlock_M = 1;
    static constexpr index_t WarpPerBlock_N = 4;
    static constexpr index_t WarpPerBlock_K = 1;

    // Warp tile extents.
    static constexpr index_t Warp_M = 16;
    static constexpr index_t Warp_N = 16;
    static constexpr index_t Warp_K = 32;

    static constexpr index_t BlockSize = 256;

    // static constexpr index_t KPack = 2; // this is used to guarantee every thread can do dwordx4
    // TODO: note Nr/Kr/W need consider KPack
    static constexpr index_t Block_W  = Warp_N * Warp_K;  // 16 * 32 = 512 elements
    static constexpr index_t Block_Nr = Block_N / Warp_N; // 128 / 16 = 8
    static constexpr index_t Block_Kr = Block_K / Warp_K; // 256 / 32 = 8

    // How many warp tiles each warp covers along every dimension.
    static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
    static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
    static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8

    // Builds the static tile distribution of the C block tile: a block-level
    // encoding of Repeat_M x Repeat_N repeats over WarpPerBlock_M x
    // WarpPerBlock_N warps, embedded with the C-warp encoding of the 16x16x32
    // transposed-C MFMA warp gemm.
    static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
    {
        using WarpGemm = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;

        using OuterEncoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<2, 1>, // !! note here is different (presumably N before M -- TODO confirm)
            sequence<0, 0>>;

        constexpr auto embedded_encoding = detail::make_embed_tile_distribution_encoding(
            OuterEncoding{}, typename WarpGemm::CWarpDstrEncoding{});

        constexpr auto block_distribution = make_static_tile_distribution(embedded_encoding);
        return block_distribution;
    }

    // Returns the shared-memory size in bytes needed for the shuffled output.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        //                     y     y     p     p      p      y
        // reg before shfl  M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
        // but order is N0*M0*Nv
        // in LDS we need store as
        // M0(2) * N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
        //   y       y    wave-id lid/16    lid%16   v
        constexpr index_t M0  = 2;
        constexpr index_t N0  = 2;
        constexpr index_t Nl  = 4;
        constexpr index_t Nw  = 4;
        constexpr index_t Mw  = 16;
        constexpr index_t Nv  = 4;
        constexpr index_t Pad = 4; // the "+ 4" added to each Mw*Nv row in the layout above

        return M0 * N0 * Nl * Nw * (Mw * Nv + Pad) * sizeof(bf16_t);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
View file @
aef2b33c
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x
512
_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x
256
_1x4x1_16x16x32.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -14,7 +14,7 @@ namespace ck_tile {
...
@@ -14,7 +14,7 @@ namespace ck_tile {
// A in smem, B load from global
// A in smem, B load from global
// require 4 wave, occupancy=1c
// require 4 wave, occupancy=1c
struct
FlatmmSn_32x128x
512
_1x4x1_16x16x32_BF16_itl
:
public
FlatmmSn_32x128x
512
_1x4x1_16x16x32_Base
struct
FlatmmSn_32x128x
256
_1x4x1_16x16x32_BF16_itl
:
public
FlatmmSn_32x128x
256
_1x4x1_16x16x32_Base
{
{
using
BDataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
...
@@ -118,7 +118,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
...
@@ -118,7 +118,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
#pragma clang diagnostic ignored "-Winline-asm"
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x
512
_1x4x1_16x16x16_itl.inc"
#include "uk/flatmm_sn_uk_gfx9_32x128x
256
_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
:
[
smem_
]
"+r"
(
smem
),
// [s_loop_cnt]"+s"(loop_cnt),
// [s_loop_cnt]"+s"(loop_cnt),
...
@@ -181,10 +181,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
...
@@ -181,10 +181,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
...
@@ -262,7 +262,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
...
@@ -262,7 +262,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl : public FlatmmSn_32x128x512_
}
}
};
};
struct
FlatmmSn_32x128x
512
_1x4x1_16x16x32_FP16_itl
:
public
FlatmmSn_32x128x
512
_1x4x1_16x16x32_Base
struct
FlatmmSn_32x128x
256
_1x4x1_16x16x32_FP16_itl
:
public
FlatmmSn_32x128x
256
_1x4x1_16x16x32_Base
{
{
using
BDataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
using
ODataType
=
bf16_t
;
...
@@ -288,7 +288,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
...
@@ -288,7 +288,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_b
,
// stride b is fixed to blockKr * blockW, but still can adjust
index_t
tile_offset_o
)
index_t
tile_offset_o
)
{
{
static_assert
(
BCoords
::
size
()
==
8
);
// 8
static_assert
(
BCoords
::
size
()
==
4
);
// 8
static_assert
(
OCoords
::
size
()
==
8
);
static_assert
(
OCoords
::
size
()
==
8
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
const
index_t
tile_stride_b_bytes
=
tile_offset_b
*
sizeof
(
BDataType
);
...
@@ -365,7 +365,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
...
@@ -365,7 +365,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
#pragma clang diagnostic ignored "-Winline-asm"
#pragma clang diagnostic ignored "-Winline-asm"
asm
volatile
(
asm
volatile
(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x
512
_1x4x1_16x16x16_itl.inc"
#include "uk/flatmm_sn_uk_gfx9_32x128x
256
_1x4x1_16x16x16_itl.inc"
#undef CK_TILE_FLATMM_UK_MFMA
#undef CK_TILE_FLATMM_UK_MFMA
:
[
smem_
]
"+r"
(
smem
),
:
[
smem_
]
"+r"
(
smem
),
[
s_loop_cnt
]
"+s"
(
n
),
[
s_loop_cnt
]
"+s"
(
n
),
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
View file @
aef2b33c
...
@@ -33,7 +33,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
...
@@ -33,7 +33,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
// TODO: note Nr/Kr/W need consider KPack
// TODO: note Nr/Kr/W need consider KPack
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_W
=
Warp_N
*
Warp_K
;
// 512 element
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Nr
=
Block_N
/
Warp_N
;
// 32 element, 4 per wave
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
//
4
static
constexpr
index_t
Block_Kr
=
Block_K
/
Warp_K
;
//
16
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Warp_M
*
WarpPerBlock_M
);
// 2
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 2
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Warp_N
*
WarpPerBlock_N
);
// 2
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
View file @
aef2b33c
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
aef2b33c
...
@@ -807,16 +807,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -807,16 +807,16 @@ struct FusedMoeGemmPipelineFlatmmPolicy
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
!
Problem
::
Traits
::
IsGateOnly
)
{
{
return
Flatmm_32x
256
x128_1x4x1_16x16x32_BF16
{};
return
Flatmm_32x
512
x128_1x4x1_16x16x32_BF16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
fp16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
fp16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
!
Problem
::
Traits
::
IsGateOnly
)
{
{
return
Flatmm_32x
256
x128_1x4x1_16x16x32_FP16
{};
return
Flatmm_32x
512
x128_1x4x1_16x16x32_FP16
{};
}
}
}
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
aef2b33c
...
@@ -199,8 +199,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -199,8 +199,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
if
(
row_ids_a
[
0
]
>=
kargs
.
num_tokens
)
return
;
auto
a_res
=
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
...
@@ -266,8 +265,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -266,8 +265,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
d_coords
=
[
&
]()
{
auto
d_coords
=
[
&
]()
{
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Kr0_
=
BlockShape
::
Block_Kr1
/
Kr1_
;
//4
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kr0_
=
BlockShape
::
Block_Kr1
/
Kr1_
;
//4
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Kv_
=
8
;
constexpr
index_t
Kv_
=
8
;
...
@@ -300,7 +299,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -300,7 +299,9 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
bridge_sst_win
=
[
&
]()
{
auto
bridge_sst_win
=
[
&
]()
{
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDistGUMerge
();
// constexpr auto dist_ = IsGateOnly ? Policy::template GetUK_0<Problem>().MakeCBlockDist()
// : Policy::template GetUK_0<Problem>().MakeCBlockDistGUMerge();
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
desc_
.
get_lengths
(),
desc_
.
get_lengths
(),
...
@@ -315,11 +316,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -315,11 +316,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
w_scale
=
GetWeightScale
(
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
if
(
row_ids_a
[
0
]
>=
kargs
.
num_tokens
)
//
if (row_ids_a[0] >= kargs.num_tokens)
return
;
//
return;
auto
uk_0_g
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0_g
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0_g
(
a_res
,
auto
acc_0
_full
=
uk_0_g
(
a_res
,
a_coords
,
a_coords
,
g_res
,
g_res
,
g_coords
,
g_coords
,
...
@@ -328,7 +329,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -328,7 +329,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// auto acc_0 = IsGateOnly ? acc_0_full : Policy::template GetUK_0<Problem>().MakeCBlockTileGUMerge();
auto
acc_0
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockTileGUMerge
();
if
(
!
IsGateOnly
)
{
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
)
{
acc_0
(
idx0
)
=
acc_0_full
(
idx0
);
});
}
// fast GeLu
// fast GeLu
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
GateActivation
,
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
GateActivation
,
ck_tile
::
element_wise
::
FastGeluAsm
>
)
ck_tile
::
element_wise
::
FastGeluAsm
>
)
...
@@ -350,37 +357,18 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -350,37 +357,18 @@ struct FusedMoeGemmPipeline_FlatmmUk
[
&
](
auto
idx0
)
{
typename
Problem
::
GateActivation
{}(
acc_0
(
idx0
),
acc_0
(
idx0
));
},
[
&
](
auto
idx0
)
{
typename
Problem
::
GateActivation
{}(
acc_0
(
idx0
),
acc_0
(
idx0
));
},
sequence
<
1
,
1
>
{});
sequence
<
1
,
1
>
{});
}
}
if
(
!
IsGateOnly
)
{
for
(
auto
i
=
0
;
i
<
BlockShape
::
Repeat_N0
;
i
++
)
{
acc_0
.
get_thread_buffer
()[
4
*
i
+
0
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
0
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
1
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
1
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
2
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
2
];
acc_0
.
get_thread_buffer
()[
4
*
i
+
3
]
*=
acc_0_full
.
get_thread_buffer
()[
4
*
(
i
+
BlockShape
::
Repeat_N0
)
+
3
];
}
}
auto
y_pre
=
acc_0
;
auto
y_pre
=
acc_0
;
block_sync_lds
();
block_sync_lds
();
// up
// if(!IsGateOnly)
// {
// // up ptr. add half expert_stride_0 as offset.
// auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
// auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
// auto u_coords =
// generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
// number<decltype(u_win)::NumAccess_NonLinear>{});
// // reuse UK0
// auto uk_0_u = Policy::template GetUK_0<Problem>();
// auto acc_0_u = uk_0_u(a_res,
// a_coords,
// u_res,
// u_coords,
// smem,
// kargs.hidden_size,
// BlockShape::Block_K0, // tile offset for B matrix each unroll
// BlockShape::Block_Kr0 *
// BlockShape::Block_W0); // tile offset for B matrix each unroll
// // elementwise mul gate*up.
// sweep_tile(
// y_pre,
// [&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
// sequence<1, 1>{});
// block_sync_lds();
// }
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
block_sync_lds
();
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment