gaoqiong / composable_kernel_ROCM · Commits

Commit 49c39b51, authored Nov 05, 2024 by carlushuang
parent 03c6448b

    moe pipeline

Showing 20 changed files with 2857 additions and 29 deletions (+2857, -29)
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp              +2    -2
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp       +6    -6
example/ck_tile/06_permute/permute.cpp                                           +1    -1
example/ck_tile/15_fused_moe/fused_moegemm.hpp                                   +66   -0
example/ck_tile/15_fused_moe/main.cpp                                            +415  -0
include/ck_tile/core/arch/amd_buffer_addressing.hpp                              +112  -0
include/ck_tile/core/tensor/buffer_view.hpp                                      +98   -6
include/ck_tile/core/tensor/load_tile.hpp                                        +27   -9
include/ck_tile/core/tensor/tensor_view.hpp                                      +42   -0
include/ck_tile/core/tensor/tile_window.hpp                                      +67   -0
include/ck_tile/core/tensor/tile_window_linear.hpp                               +52   -0
include/ck_tile/core/tensor/update_tile.hpp                                      +52   -3
include/ck_tile/host/reference/reference_permute.hpp                             +22   -2
include/ck_tile/ops/fused_moe.hpp                                                +14   -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp                    +362  -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp                     +124  -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp          +33   -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp         +572  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp  +744  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp        +46   -0
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp (+2, -2)

@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
     else if(t.permute.compare("0,1,3,4,2,5") == 0)
     {
         constexpr matrix_core_permute_style pstyle =
-            matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
+            matrix_core_permute_style::b_nr_kr_kw_nw_kv;
         using Kernel =
             matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;

@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
     else if(t.permute.compare("0,1,3,4,2,5") == 0)
     {
         constexpr matrix_core_permute_style pstyle =
-            matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
+            matrix_core_permute_style::b_nr_kr_kw_nw_kv;
         using Kernel =
             matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp (+6, -6)

@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
 {
     permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
     permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
-    permute_b_nr_kr_kw_nw_kv    = 2, // 0,1,3,4,2,5
-    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
+    b_nr_kr_kw_nw_kv            = 2, // 0,1,3,4,2,5
+    b_nr_kr_waveflatten         = b_nr_kr_kw_nw_kv,
 };

 // assume this is B matrix, originally we have batch*n*k

@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
         else
         {
             // clang-format off
-            // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
+            // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
             constexpr index_t Kv = Alignment;
             constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
                 make_tuple(sequence<0>{}, sequence<1>{}));
             return tmp_1;
 #else
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
             constexpr index_t kv = Alignment;
             constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
         else
         {
 #if MERGE_2D_013425
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
             return make_tile_window(dst_view,
                                     make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                     {i_n * NPerBlock, i_k * KPerBlock},
                                     get_dst_dist());
 #else
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
             constexpr index_t kv = Alignment;
             constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
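The "0,1,3,4,2,5" style above is exactly b_nr_kr_kw_nw_kv: the B matrix batch*n*k is viewed as [b, nr, nw, kr, kw, kv] (with kv = Alignment, nw = kAMLane, kw = kABKLane) and reordered to [b, nr, kr, kw, nw, kv]. A minimal host-side sketch of that index mapping, written against the reference_permute helper this commit extends; the split factors below are illustrative assumptions, not values from the commit:

// sketch only: illustrates the "0,1,3,4,2,5" index mapping on the host
#include "ck_tile/host.hpp"

void permute_b_nr_kr_kw_nw_kv_sketch()
{
    // assumed split factors: n = nr * nw, k = kr * kw * kv
    constexpr ck_tile::index_t b = 1, nr = 2, nw = 16, kr = 4, kw = 4, kv = 8;

    // view batch*n*k as the 6-d tensor [b, nr, nw, kr, kw, kv] (dims 0,1,2,3,4,5)
    ck_tile::HostTensor<float> x({b, nr, nw, kr, kw, kv});
    ck_tile::FillUniformDistribution<float>{-.5f, .5f}(x);

    // "0,1,3,4,2,5" picks the dims in that order, i.e. [b, nr, kr, kw, nw, kv]
    auto y = ck_tile::reference_permute(x, {0, 1, 3, 4, 2, 5});
}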
example/ck_tile/06_permute/permute.cpp (+1, -1)

@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
 {
     if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
     {
-        // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
+        // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
         matrix_core_swizzle_traits t;
         t.data_type = data_type;
         t.permute   = arg_parser.get_str("perm");
example/ck_tile/15_fused_moe/fused_moegemm.hpp (new file, +66)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>

// this is only a convenient structure for creating an example
// this is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;

template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::bf16_t;
    using GDataType            = ck_tile::bf16_t;
    using DDataType            = ck_tile::bf16_t;
    using AccDataType          = float;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using W0ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using W1ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::int8_t;
    using GDataType            = ck_tile::int8_t;
    using DDataType            = ck_tile::int8_t;
    using AccDataType          = int32_t;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using W0ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using W1ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

// runtime args
struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
{
};

// This is the public API, will be generated by script
struct fused_moegemm_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int fused_quant;     // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
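For reference, a minimal host-side sketch of driving this API (field values are illustrative; fused_moegemm_args is still a placeholder alias of the layernorm2d host args in this commit, so it is default-constructed):

// sketch only: how a caller could invoke the public entry point declared above
float fused_moegemm_example()
{
    fused_moegemm_traits t{};
    t.prec_i      = "bf16"; // input precision
    t.prec_w      = "bf16"; // weight precision
    t.prec_o      = "bf16"; // output precision
    t.prec_st     = "fp32"; // token scale data type
    t.prec_sw     = "fp32"; // weight scale data type
    t.prec_sq     = "fp32"; // smooth quant scale
    t.prec_kw     = "fp32"; // topk-weight data type
    t.fused_quant = 0;      // no fused quantization

    fused_moegemm_args a{}; // placeholder args in this commit

    // timing config mirrors the stream_config usage in main.cpp
    return fused_moegemm(t, a, ck_tile::stream_config{nullptr, true});
}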
example/ck_tile/15_fused_moe/main.cpp (new file, +415; work-in-progress scaffolding adapted from the layernorm2d example)

#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring>

// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

// mfma_type, 0:32x32, 1:16x16
template <typename H>
auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0)
{
    static_assert(t.get_lengths().size() == 3);
    int b_ = t.get_lengths()[0];
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[2];
    if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
    {
        std::vector<ck_tile::index_t> new_lens{b_, n_ / 32, 32, k_ / 16, 2, 8};
    }
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("t", "128", "num input tokens")
        .insert("e", "32", "num of experts")
        .insert("k", "5", "topk")
        .insert("h", "8192", "hidden_size of this model")
        .insert("i", "8192", "intermediate_size between 2 gemms of FFN")
        .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
        .insert("bm", "32", "blocking factor for sorted tokens")
        .insert("tp", "8", "tensor parallel size")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec_i", "bf16", "input precision")
        .insert("prec_w", "bf16", "weight precision")
        .insert("prec_o", "bf16", "output precision")
        .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
        .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
        .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
        .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
        .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
        .insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
        .insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t tokens            = arg_parser.get_int("t");
    ck_tile::index_t experts           = arg_parser.get_int("e");
    ck_tile::index_t topk              = arg_parser.get_int("k");
    ck_tile::index_t hidden_size       = arg_parser.get_int("h");
    ck_tile::index_t intermediate_size = arg_parser.get_int("i");
    ck_tile::index_t stride            = arg_parser.get_int("stride");
    ck_tile::index_t block_m           = arg_parser.get_int("bm");
    if(stride < 0)
        stride = hidden_size;
    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_w  = arg_parser.get_str("prec_w");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_st = arg_parser.get_str("prec_st");
    std::string prec_sw = arg_parser.get_str("prec_sw");
    std::string prec_sq = arg_parser.get_str("prec_sq");
    std::string prec_kw = arg_parser.get_str("prec_kw");
    prec_st = (prec_st == "auto") ? "fp32" : prec_st;
    prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
    prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
    prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;

    int kname         = arg_parser.get_int("kname");
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");
    int fused_quant   = arg_parser.get_int("fquant");
    int gonly         = arg_parser.get_int("gonly");
    int balance       = arg_parser.get_int("balance");
    int tp            = arg_parser.get_int("tp");

    ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp;

    using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

    using ADataType            = typename TypeConfig::ADataType;
    using GDataType            = typename TypeConfig::GDataType;
    using DDataType            = typename TypeConfig::DDataType;
    using AccDataType          = typename TypeConfig::AccDataType;
    using ODataType            = typename TypeConfig::ODataType;
    using AScaleDataType       = typename TypeConfig::AScaleDataType;
    using W0ScaleDataType      = typename TypeConfig::W0ScaleDataType;
    using W1ScaleDataType      = typename TypeConfig::W1ScaleDataType;
    using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
    using TopkWeightDataType   = typename TypeConfig::TopkWeightDataType;
    using IndexDataType        = typename TypeConfig::IndexDataType;

    // host verify
    ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
    ck_tile::HostTensor<ADataType> g_host({e, shared_intermediate_size, hidden_size});
    ck_tile::HostTensor<ADataType> d_host({e, intermediate_size, hidden_size});
    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});

    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});

    ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
    ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});

    ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
    ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});

    ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
    ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});

    ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
    ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
    ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
    ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);

    ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());

    x_buf.ToDevice(a_host.data());
    gamma_buf.ToDevice(gamma_host.data());
    beta_buf.ToDevice(beta_host.data());
    x_residual_buf.ToDevice(x_residual_host.data());
    x_scale_buf.ToDevice(x_scale_host.data());

    auto prec_str = [&]() {
        auto base_str = prec_i;
        if(prec_i != prec_o)
        {
            base_str += "|" + prec_o;
        }
        if(fused_quant == 1)
        {
            base_str += std::string("(") + prec_sy + ")";
        }
        return base_str;
    }();

    std::cout << "[" << prec_str << "]"
              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

    layernorm2d_fwd_traits traits{
        prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};

    layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                              fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
                              fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
                              gamma_buf.GetDeviceBuffer(),
                              beta_buf.GetDeviceBuffer(),

                              y_buf.GetDeviceBuffer(),
                              fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
                              fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
                              nullptr, // p_mean, unsupported yet
                              nullptr, // p_invStd, unsupported yet

                              epsilon,
                              m,
                              n,
                              stride};

    float ave_time = layernorm2d_fwd(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    if(ave_time < 0)
    {
        std::cout << " not supported!" << std::endl << std::flush;
        return false;
    }

    std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        // reference
        if(fused_add != 0)
        {
            // fused pre_add/pre_add_store
            // TODO we accumulate directly to a_host for simplcity here...
            std::transform(a_host.mData.cbegin(),
                           a_host.mData.cend(),
                           x_residual_host.mData.cbegin(),
                           a_host.mData.begin(),
                           [](auto x_, auto r_) {
                               auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
                                         ck_tile::type_convert<ComputeDataType>(r_);
                               return ck_tile::type_convert<ADataType>(o_);
                           });
        }

        ck_tile::reference_layernorm2d_fwd<ADataType,
                                           GammaDataType,
                                           BetaDataType,
                                           ComputeDataType,
                                           YDataType,
                                           MeanDataType,
                                           InvStdDataType>(
            a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);

        if(fused_quant != 0)
        {
            auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
                int N_ = acc_.mDesc.get_lengths()[1];
                if(fused_quant == 1)
                {
                    for(int n_ = 0; n_ < N_; n_++)
                    {
                        // input smooth outlier
                        acc_(m_, n_) = acc_(m_, n_) *
                                       ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
                    }
                }
                ComputeDataType absmax = static_cast<ComputeDataType>(0);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    const auto a = ck_tile::abs(acc_(m_, n_));
                    absmax       = a > absmax ? a : absmax;
                }
                // printf("cpu:absmax:%f\n", absmax);
                ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
                y_scale_host_ref(m_)    = ck_tile::type_convert<YScaleDataType>(y_scale);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
                }
            };

            ck_tile::reference_layernorm2d_fwd<ADataType,
                                               GammaDataType,
                                               BetaDataType,
                                               ComputeDataType,
                                               YDataType,
                                               MeanDataType,
                                               InvStdDataType>(a_host,
                                                               gamma_host,
                                                               beta_host,
                                                               y_host_ref,
                                                               mean_host_ref,
                                                               invStd_host_ref,
                                                               epsilon,
                                                               dquant_functor);
        }
        else
        {
            ck_tile::reference_layernorm2d_fwd<ADataType,
                                               GammaDataType,
                                               BetaDataType,
                                               ComputeDataType,
                                               YDataType,
                                               MeanDataType,
                                               InvStdDataType>(
                a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
        }

        y_buf.FromDevice(y_host_dev.data());

        ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
        if(fused_add == 1)
        {
            y_residual_buf.FromDevice(y_residual_host_dev.data());
        }

        auto [rtol, atol] = get_elimit<InDataType>();

        if(stride == n)
        {
            pass = ck_tile::check_err(
                y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
            if(fused_add == 1)
            {
                pass &= ck_tile::check_err(y_residual_host_dev,
                                           a_host,
                                           std::string("ADD Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
        }
        else
        {
            for(int i_r = 0; i_r < m; i_r++)
            {
                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
                                                      y_host_dev.begin() + i_r * stride + n);
                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
                                                      y_host_ref.begin() + i_r * stride + n);
                pass &= ck_tile::check_err(y_host_dev_row,
                                           y_host_ref_row,
                                           std::string("OUT[") + std::to_string(i_r) +
                                               std::string("] Error: Incorrect results!"),
                                           rtol,
                                           atol);

                if(fused_add == 1)
                {
                    std::vector<YResidualDataType> y_residual_host_dev_row(
                        y_residual_host_dev.begin() + i_r * stride,
                        y_residual_host_dev.begin() + i_r * stride + n);
                    std::vector<YResidualDataType> y_residual_host_ref_row(
                        a_host.begin() + i_r * stride, a_host.begin() + i_r * stride + n);
                    pass &= ck_tile::check_err(y_residual_host_dev_row,
                                               y_residual_host_ref_row,
                                               std::string("ADD[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }

        if(fused_quant == 1)
        {
            y_scale_buf.FromDevice(y_scale_host_dev.data());
            pass &= ck_tile::check_err(y_scale_host_dev,
                                       y_scale_host_ref,
                                       std::string("SCALE Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_sx = arg_parser.get_str("prec_sx");
    std::string prec_sy = arg_parser.get_str("prec_sy");
    if(prec_o == "auto")
    {
        prec_o = prec_i;
    }
    if(prec_sx == "auto")
    {
        prec_sx = "fp32";
    }
    if(prec_sy == "auto")
    {
        prec_sy = "fp32";
    }

    int save_mv = arg_parser.get_int("save_mv");

    // no dynamic quant case
    if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    // dynamic quant case, only in inference
    else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
    }

    return -3;
}
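Note that shuffle_moe_weight() above is still a stub in this commit: it derives the 6-d view {b, n/32, 32, k/16, 2, 8} for the 32x32 mfma case but never builds the shuffled tensor. A hedged sketch of one plausible completion, reusing reference_permute and assuming the same b_nr_kr_kw_nw_kv ordering as example 06_permute (the chosen perm "0,1,3,4,2,5" is an assumption, not taken from the commit):

// sketch only: a possible completion of shuffle_moe_weight for the bf16/fp16 32x32 mfma case
template <typename DataType>
auto shuffle_moe_weight_sketch(const ck_tile::HostTensor<DataType>& t)
{
    int b_ = t.get_lengths()[0];
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[2];

    // reinterpret the packed [b, n, k] data as [b, n/32, 32, k/16, 2, 8]
    ck_tile::HostTensor<DataType> x({b_, n_ / 32, 32, k_ / 16, 2, 8});
    x.mData = t.mData; // same element count, row-major packed

    // hypothetical shuffle order, mirroring b_nr_kr_kw_nw_kv: [b, nr, kr, kw, nw, kv]
    return ck_tile::reference_permute(x, {0, 1, 3, 4, 2, 5});
}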
include/ck_tile/core/arch/amd_buffer_addressing.hpp (+112, -0)

@@ -53,6 +53,11 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
 // clang-format on
 } // namespace impl

+// TODO: this is hot-tmp fix to unblock user case. Need refactor into template arg
+#ifndef CK_TILE_BUFFER_LOAD_AGPR
+#define CK_TILE_BUFFER_LOAD_AGPR 0
+#endif
+
 // TODO: glc/slc/...
 template <index_t bytes, bool pre_nop = false>
 struct buffer_load;

@@ -74,6 +79,19 @@ struct buffer_load<16, pre_nop>
 {
     static_assert(sizeof(T) == 16);
     using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
+#if CK_TILE_BUFFER_LOAD_AGPR
+    if constexpr(pre_nop)
+        asm volatile("s_nop 4\n"
+                     "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                     : "=a"(reinterpret_cast<mbuf_t&>(value))
+                     : "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
+    else
+        asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                     : "=a"(reinterpret_cast<mbuf_t&>(value))
+                     : "v"(v_offset), "s"(res), "n"(i_offset)
+                     : "memory");
+#else
     if constexpr(pre_nop)
         asm volatile("s_nop 4\n"
                      "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"

@@ -85,6 +103,7 @@ struct buffer_load<16, pre_nop>
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "n"(i_offset)
                      : "memory");
+#endif
     }
 };

@@ -621,6 +640,60 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }

+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add_if;
+
+template <bool pre_nop>
+struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t flag = 1)
+    {
+        static_assert(sizeof(T) == 4);
+        auto save_exec = __builtin_amdgcn_read_exec();
+        using mbuf_t   = float;
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset),
+                       "v"(flag), "s"(save_exec)
+                     : "memory");
+    }
+};
+
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add;
+
+template <bool pre_nop>
+struct buffer_atomic_add<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t /*flag = 1*/)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = float;
+        asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
+                     : "memory");
+    }
+};
+
 namespace impl {
 // below type indicate the data type used for buffer load inline asm
 // clang-format off

@@ -2378,6 +2451,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
 #endif
 }

+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = true,
+          bool pre_nop                        = false>
+CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
+                                              T* p_dst_wave,
+                                              const index_t dst_thread_element_offset,
+                                              const index_t dst_linear_element_offset,
+                                              const bool dst_thread_element_valid,
+                                              const index_t dst_element_space_size,
+                                              bool_constant<pre_nop> = {})
+{
+    const int32x4_t dst_wave_buffer_resource =
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
+
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
+
+    if constexpr(oob_conditional_check)
+    {
+        buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              0,
+                                              dst_linear_addr_offset,
+                                              dst_thread_element_valid);
+    }
+    else
+    {
+        buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           0,
+                                           dst_linear_addr_offset,
+                                           1);
+    }
+}
+
 // buffer_atomic_max requires:
 // 1) p_dst_wave must point to global memory
 // 2) p_dst_wave must be a wavewise pointer.
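The two new inline-asm helpers differ only in predication: buffer_atomic_add_if gates the packed-bf16 atomic on a per-lane flag by rewriting the EXEC mask (v_cmpx_le_u32 keeps a lane active iff 1 <= flag; s_mov_b64 restores the saved mask afterwards), while buffer_atomic_add issues the atomic unconditionally. As plain scalar logic, the guarded path behaves like this sketch (conceptual only, one "lane" at a time; the packed bf16x2 atomic is stood in for by a float add):

// sketch only: per-lane meaning of buffer_atomic_add_if, not the real ISA path
void atomic_add_if_scalar_sketch(float* dst, int offset, float packed_value, int flag)
{
    if(1 <= flag) // v_cmpx_le_u32 exec, 1, flag: lane stays active only when flag is set
    {
        dst[offset] += packed_value; // global_atomic_pk_add_bf16 on the real hardware
    }
    // s_mov_b64 exec, save_exec: the original execution mask is restored
}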
include/ck_tile/core/tensor/buffer_view.hpp (+98, -6)

@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <memory_operation_enum Op,
               typename X,
+              bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
-    CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+    CK_TILE_DEVICE void update(index_t i,
+                               index_t linear_offset,
+                               bool is_valid_element,
+                               const X& x,
+                               bool_constant<oob_conditional_check> = {})
     {
         if constexpr(Op == memory_operation_enum::set)
         {
-            this->template set<X>(i, linear_offset, is_valid_element, x);
+            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
         }
         else if constexpr(Op == memory_operation_enum::atomic_add)
         {
-            this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_add<X, oob_conditional_check>(
+                i, linear_offset, is_valid_element, x);
         }
         else if constexpr(Op == memory_operation_enum::atomic_max)
         {
-            this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_max<X, oob_conditional_check>(
+                i, linear_offset, is_valid_element, x);
         }
         // FIXME: remove memory_operation_enum::add
         else if constexpr(Op == memory_operation_enum::add)
         {
-            auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
-            this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
+            auto tmp =
+                this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
+            this->template set<X, oob_conditional_check>(
+                i, linear_offset, is_valid_element, x + tmp);
             // tmp += x;
             // this->template set<X>(i, is_valid_element, tmp);
         }
     }

+    // i is offset of T, not X. i should be aligned to X
+    template <memory_operation_enum Op,
+              typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false,
+              typename std::enable_if<
+                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                  bool>::type = false>
+    CK_TILE_DEVICE void update_raw(index_t i,
+                                   index_t linear_offset,
+                                   bool is_valid_element,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop>               = {})
+    {
+        if constexpr(Op == memory_operation_enum::set)
+        {
+            this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
+        }
+        else if constexpr(Op == memory_operation_enum::atomic_add)
+        {
+            this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
+                i, linear_offset, is_valid_element, x);
+        }
+        else if constexpr(Op == memory_operation_enum::atomic_max)
+        {
+            // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
+        }
+    }
+
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,

@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
     }

     template <typename X,
+              bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,

@@ -585,6 +626,57 @@ struct buffer_view<address_space_enum::global,
     }

     template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop               = true,
+              typename std::enable_if<
+                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                  bool>::type = false>
+    CK_TILE_DEVICE void
+    atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+    {
+        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
+
+        // X contains multiple T
+        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
+        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
+
+        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                      "wrong! X should contain multiple T");
+
+        static_assert(get_address_space() == address_space_enum::global,
+                      "only support global mem");
+
+#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
+            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
+#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
+#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
+#else
+        bool constexpr use_amd_buffer_addressing = false;
+#endif
+
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+        amd_buffer_atomic_add_raw<remove_cvref_t<T>,
+                                  t_per_x,
+                                  Coherence,
+                                  oob_conditional_check,
+                                  pre_nop>(
+            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
+    }
+
+    template <typename X,
+              bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
include/ck_tile/core/tensor/load_tile.hpp (+27, -9)

@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename DistributedTensor_,

@@ -51,15 +55,17 @@ template <typename DistributedTensor_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                               const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
+    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 /**

@@ -76,6 +82,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,

@@ -83,11 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                                    WindowLengths_,
                                                                    TileDistribution_,
                                                                    NumCoord>& tile_window,
+                                  number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
 {
     tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename T,

@@ -95,6 +103,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,

@@ -102,11 +111,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                                  number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
 {
     tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,

@@ -114,6 +124,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto

@@ -122,11 +133,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                                WindowLengths_,
                                                                TileDistribution_,
                                                                NumCoord>& tile_window,
+                    number<i_access>                     = {},
                     bool_constant<oob_conditional_check> = {},
                     bool_constant<pre_nop>               = {})
 {
-    return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(lds_tile,
+                                      number<i_access>{},
+                                      bool_constant<oob_conditional_check>{},
+                                      bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,

@@ -134,6 +148,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,

@@ -141,11 +156,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                    number<i_access>                     = {},
                     bool_constant<oob_conditional_check> = {},
                     bool_constant<pre_nop>               = {})
 {
-    return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(lds_tile,
+                                      number<i_access>{},
+                                      bool_constant<oob_conditional_check>{},
+                                      bool_constant<pre_nop>{});
 }

 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
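Every overload gains the same defaulted i_access parameter, so existing call sites keep their meaning (number<-1> still means "issue every access") while a pipeline can now issue one access of the window per call. A sketch, with placeholder template parameters standing in for a concrete window/tile pair:

// sketch only: issuing whole-window vs. per-access loads with the new parameter
template <typename Tile, typename Window>
CK_TILE_DEVICE void load_slices_sketch(Tile& t, const Window& w)
{
    using namespace ck_tile;
    auto whole = load_tile(w);        // default number<-1>{}: all accesses, as before
    load_tile_raw(t, w, number<0>{}); // only access 0 of the space-filling curve
    load_tile_raw(t, w, number<1>{},  // then access 1, OOB checking kept on
                  bool_constant<true>{});
    (void)whole;
}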
include/ck_tile/core/tensor/tensor_view.hpp (+42, -0)

@@ -333,6 +333,48 @@ struct tensor_view
             coord.get_offset(), linear_offset, is_valid_element, x);
     }

+    // X is vector of DataType.
+    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
+    template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void
+    update_vectorized_elements_raw(const TensorCoord& coord,
+                                   index_t linear_offset,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop>               = {})
+    {
+        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
+            coord.get_offset(),
+            linear_offset,
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            x);
+    }
+
+    template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void
+    update_vectorized_elements_raw(const TensorCoord& coord,
+                                   index_t linear_offset,
+                                   bool is_valid_element,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop>               = {})
+    {
+        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
+            coord.get_offset(), linear_offset, is_valid_element, x);
+    }
+
     CK_TILE_HOST_DEVICE void print() const
     {
         printf("tensor_view{");
include/ck_tile/core/tensor/tile_window.hpp (+67, -0)

@@ -785,6 +785,73 @@ struct tile_window_with_static_distribution
         });
     }

+    template <index_t i_access_unsupport_ = -1,
+              bool oob_conditional_check  = true,
+              bool pre_nop                = false>
+    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                   number<i_access_unsupport_>          = {},
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop>               = {}) const
+    {
+        using Traits   = load_store_traits;
+        using vector_t = typename Traits::vector_t;
+        using SFC_Ys   = typename Traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            /// TODO: use structure binding (to be captured later) if compiled in C++20
+            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
+            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];
+
+            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
+                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+
+                // data index [y0, y1, ...]
+                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+
+                // read from distributed tensor
+                vector_t vec_value;
+
+                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                    constexpr auto idx_ys = generate_tuple(
+                        [&](auto jj) {
+                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                            : idx_ys_start[jj];
+                        },
+                        number<NDimY>{});
+
+                    constexpr index_t d =
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                    vec_value.template get_as<DataType>()(j) =
+                        dstr_tensor.get_thread_buffer().template at<d>();
+                });
+
+                // write into bottom tensor
+                get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
+                    bottom_tensor_thread_coord,
+                    0,
+                    vec_value,
+                    bool_constant<oob_conditional_check>{},
+                    bool_constant<pre_nop>{});
+
+                // move thread coordinate
+                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+                {
+                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+
+                    constexpr auto idx_diff_ps_ys = container_concat(
+                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
+                        idx_diff_ys);
+
+                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                }
+            });
+        });
+    }
+
     // move thread's botom tensor coordiante
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin
include/ck_tile/core/tensor/tile_window_linear.hpp (+52, -0)

@@ -849,6 +849,58 @@ struct tile_window_linear
         WINDOW_DISPATCH_ISSUE();
     }

+    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                   number<i_access>                     = {},
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop>               = {}) const
+    {
+        using vector_t = typename traits::vector_t;
+        using SFC_Ys   = typename traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess          = number<i_access_>{};
+            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
+            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
+            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
+            auto bottom_tensor_flag         = cached_flags_[IAccess];
+
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
+
+            // read from distributed tensor
+            vector_t vec_value;
+
+            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_tuple(
+                    [&](auto jj) {
+                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                vec_value.template get_as<DataType>()(j) =
+                    dstr_tensor.get_thread_buffer().template at<d>();
+            });
+
+            // write into bottom tensor
+            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
+                bottom_tensor_thread_coord,
+                linear_offset,
+                bottom_tensor_flag,
+                vec_value,
+                bool_constant<oob_conditional_check>{},
+                bool_constant<pre_nop>{});
+        };
+
+        WINDOW_DISPATCH_ISSUE();
+    }
+
     // move thread's botom tensor coordiante
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin
include/ck_tile/core/tensor/update_tile.hpp (+52, -3)

@@ -41,15 +41,64 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true>
 CK_TILE_DEVICE void
 update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                  WindowLengths_,
                                                  TileDistribution_,
                                                  NumCoord>& tile_window,
-            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+            number<i_access>                     = {},
+            bool_constant<oob_conditional_check> = {})
 {
-    tile_window.update(dstr_tensor);
+    tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          index_t NumCoord,
+          typename DataType_,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
+CK_TILE_DEVICE void
+update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
+                                                     WindowLengths_,
+                                                     TileDistribution_,
+                                                     NumCoord>& tile_window,
+                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                number<i_access>                     = {},
+                bool_constant<oob_conditional_check> = {},
+                bool_constant<pre_nop>               = {})
+{
+    tile_window.update_raw(dstr_tensor,
+                           number<i_access>{},
+                           bool_constant<oob_conditional_check>{},
+                           bool_constant<pre_nop>{});
+}
+
+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
+CK_TILE_DEVICE auto update_tile_raw(tile_window_linear<BottomTensorView_,
+                                                       WindowLengths_,
+                                                       TileDistribution_,
+                                                       LinearBottomDims_>& tile_window,
+                    const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                    number<i_access>                     = {},
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop>               = {})
+{
+    tile_window.update_raw(dstr_tensor,
+                           number<i_access>{},
+                           bool_constant<oob_conditional_check>{},
+                           bool_constant<pre_nop>{});
+}
+
 } // namespace ck_tile
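update_tile_raw is the top of the new raw-update chain added by this commit: it forwards to tile_window.update_raw, which writes through tensor_view::update_vectorized_elements_raw into buffer_view::update_raw, and from there (for atomic_add destinations) into amd_buffer_atomic_add_raw. A sketch with placeholder template parameters:

// sketch only: accumulating a distributed tensor through the new raw path
template <typename Window, typename Tensor>
CK_TILE_DEVICE void accumulate_partials_sketch(Window& out_window, const Tensor& partial)
{
    using namespace ck_tile;
    update_tile_raw(out_window, partial);                 // all accesses, OOB check on
    update_tile_raw(out_window, partial,                  // or: a single access slice,
                    number<0>{}, bool_constant<false>{}); // skipping the OOB check
}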
include/ck_tile/host/reference/reference_permute.hpp
View file @
49c39b51
...
@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
    const auto x_len = x.mDesc.get_lengths();
    const auto y_len = y.mDesc.get_lengths();
...
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
        std::vector<size_t> tmp(rank, 0);
        for(index_t i = 0; i < rank; i++)
        {
            tmp[perm[i]] = y_coord[i];
        }
        return tmp;
    }();
...
@@ -54,4 +54,24 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
    make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}

template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
    auto x_shape          = x.get_lengths();
    ck_tile::index_t rank = perm.size();

    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            tmp[i] = x_shape[perm[i]];
        }
        return tmp;
    }();

    HostTensor<DataType> y(y_shape);
    reference_permute(x, y, perm);
    return y;
}
} // namespace ck_tile
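A small host-side usage sketch of the new out-of-place overload (tensor sizes here are illustrative, not from the commit):

    // y_shape[i] = x_shape[perm[i]]: a [2, 3, 4] tensor permuted with
    // perm = {2, 0, 1} yields a [4, 2, 3] tensor
    ck_tile::HostTensor<float> x({2, 3, 4});
    auto y = ck_tile::reference_permute(x, {2, 0, 1});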
include/ck_tile/ops/fused_moe.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// (the original spelled these "fused_moe_*.hpp", but the headers added by this
//  commit are named "fused_moegemm_*.hpp"; paths corrected to match)
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr   : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_post_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2) need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr / num_sorted_tiles_ptr (select one)
//
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate the original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of the above, then:
// topk_row_id(token_id)  = x % num_tokens(5)
// topk_col_id(expert_id) = x / num_tokens
// topk_row_id/col_id can be used to access the original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
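// To make the indexing above concrete: a host-side sketch (illustrative only,
// not part of this commit) that builds the implementation-1 buffers. It pads
// each expert's slice up to a multiple of M_a, with at least one tile per
// expert (matching the worked example, where empty exp-4 still occupies a tile).
//
//   struct MoeSortResult
//   {
//       std::vector<int32_t> sorted_token_ids;  // padded with a sentinel id
//       std::vector<float>   sorted_weight;
//       std::vector<int32_t> sorted_expert_ids; // one entry per M_a-sized tile
//   };
//
//   MoeSortResult sort_topk_ids(const std::vector<std::vector<int32_t>>& topk_ids,
//                               const std::vector<std::vector<float>>&   topk_weight,
//                               int32_t num_experts, int32_t m_a, int32_t pad_id)
//   {
//       MoeSortResult r{};
//       for(int32_t e = 0; e < num_experts; ++e)
//       {
//           int32_t cnt = 0;
//           for(size_t t = 0; t < topk_ids.size(); ++t)
//               for(size_t k = 0; k < topk_ids[t].size(); ++k)
//                   if(topk_ids[t][k] == e)
//                   {
//                       r.sorted_token_ids.push_back(static_cast<int32_t>(t));
//                       r.sorted_weight.push_back(topk_weight[t][k]);
//                       ++cnt;
//                   }
//           const int32_t tiles = cnt == 0 ? 1 : (cnt + m_a - 1) / m_a;
//           for(int32_t p = cnt; p < tiles * m_a; ++p)
//           {
//               r.sorted_token_ids.push_back(pad_id); // e.g. 6 in the example
//               r.sorted_weight.push_back(0.f);
//           }
//           for(int32_t i = 0; i < tiles; ++i)
//               r.sorted_expert_ids.push_back(e);
//       }
//       return r;
//   }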
namespace ck_tile {

// m: num_tokens (or tokens * input-batch)
// k: hidden_size
// n: intermediate_size used between the 2 FCs (TP slices this)
// e: num_experts
// if doing pre-shuffle:
//   nr: n / Block_Nr
//   kr: k / Block_Kr
//   w : flattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
    const void* a_ptr;              // [m, k], input token
    const void* a_scale_ptr;        // [m, 1], token scale
    const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
    const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
    const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
    const void* d_scale_ptr;        // [e, 1, k], down scale
    const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
    void* o_ptr;                    // [m, k], output token

    const void* sorted_token_ids_ptr;
    const void* sorted_weight_ptr;
    const void* sorted_expert_ids_ptr;
    const void* num_sorted_tiles_ptr;

    index_t hidden_size;       // k
    index_t intermediate_size; // n (TP slices this)
    index_t num_tokens;        // input number of tokens for the current iteration
    index_t num_experts;       // number of groups
    // index_t top_k;          // need this?
    index_t stride_token;      // row stride for input/output, should be >= hidden_size
};
// This is scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
    using Partitioner = remove_cvref_t<Partitioner_>;
    using Pipeline    = remove_cvref_t<Pipeline_>;
    using Epilogue    = remove_cvref_t<Epilogue_>; // TODO: not used

    static constexpr index_t kBlockSize = Pipeline::kBlockSize;
    // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
    // static_assert(kBlockPerCu > 0);

    using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Pipeline::Problem::ADataType;
    using GDataType            = typename Pipeline::Problem::GDataType;
    using DDataType            = typename Pipeline::Problem::DDataType;
    using AccDataType          = typename Pipeline::Problem::AccDataType;
    using ODataType            = typename Pipeline::Problem::ODataType;
    using AScaleDataType       = typename Pipeline::Problem::AScaleDataType;
    using GScaleDataType       = typename Pipeline::Problem::GScaleDataType;
    using DScaleDataType       = typename Pipeline::Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Pipeline::Problem::TopkWeightDataType;
    using IndexDataType        = typename Pipeline::Problem::IndexDataType;
    using YDataType            = typename Pipeline::Problem::YDataType;

    using Traits = typename Pipeline::Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>  { static constexpr const char* name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char* name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char* name = "bf16"; };
    template <> struct t2s<fp8_t>  { static constexpr const char* name = "fp8";  };
    template <> struct t2s<bf8_t>  { static constexpr const char* name = "bf8";  };
    template <> struct t2s<int8_t> { static constexpr const char* name = "int8"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // (the name string is not assembled yet in this commit; return empty
        //  so the function is at least well-formed)
        return {};
    }

    struct FusedMoeGemmKargs
    {
        const void* a_ptr;              // [m, k], input token
        const void* a_scale_ptr;        // [m, 1], token scale
        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
        const void* d_scale_ptr;        // [e, 1, k], down scale
        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
        void* o_ptr;                    // [m, k], output token

        const void* sorted_token_ids_ptr;
        const void* sorted_weight_ptr;
        const void* sorted_expert_ids_ptr;
        const void* num_sorted_tiles_ptr;

        index_t hidden_size;       // k
        index_t intermediate_size; // n (TP slices this)
        index_t num_tokens;        // input number of tokens for the current iteration
        index_t num_experts;       // number of groups
        // index_t top_k;          // need this?
        index_t stride_token;      // row stride for input/output, should be >= hidden_size
    };

    // TODO: switch kargs based on traits
    using Kargs = FusedMoeGemmKargs;
    using Hargs = FusedMoeGemmHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        // TODO: hargs/kargs are not guaranteed to stay layout-identical
        return bit_cast<Kargs>(hargs);
    }

    // (the original wrote TilePartitioner:: here and below; the alias is Partitioner)
    CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
    {
        return Partitioner::GridSize(num_cu, blocks_per_cu);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
            *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;

        // (tile-shape constants below are spelled BlockShape::...; the original
        //  mixed Pipeline:: and BlockShape:: for the same quantities)
        index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
        index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
        index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
        index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0

        index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
        index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

        // allocate LDS
        // (GetSmemSize() is in bytes; the element type here is kept from the original)
        __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];

        // note these ids are in units of tiles; multiply by the tile size to
        // get the element index
        const auto [sorted_tile_id, hidden_tile_id] =
            Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
        if(sorted_tile_id >= num_sorted_tiles)
            return;

        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

        // index along intermediate_size
        index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0);
        index_t hidden_idx_nr =
            __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0);

        const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
        const auto sorted_token_id =
            a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;

        index_t token_id =
            reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
        auto topk_weight =
            reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];

        const auto a_window = [&]() {
            // A is already pre-padded in the previous kernel
            const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
            const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                a_ptr,
                make_tuple(kargs.num_tokens, kargs.hidden_size),
                make_tuple(kargs.stride_token, 1),
                number<Pipeline::kAlignmentA>{},
                number<1>{});
            // the gather happens here, via an indexing transform
            const auto a_gather_view_ = transform_tensor_view(
                a_view_,
                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                           make_pass_through_transform(kargs.hidden_size)),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto a_window_ = make_tile_window(
                a_gather_view_,
                make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
                {0, 0});
            return a_window_;
        }();

        // TODO: g tile using NSub to get less register pressure
        const auto g_window = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     hidden_idx_nr * kr_0 * BlockShape::Block_W0;
            const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<Pipeline::kAlignmentG>{},
                number<1>{});
            const auto g_view_1_ =
                pad_tensor_view(g_view_,
                                make_tuple(number<BlockShape::Block_Nr0>{},
                                           number<BlockShape::Block_Kr0>{},
                                           number<BlockShape::Block_W0>{}),
                                sequence<PadIntermediateSize, PadHiddenSize, 0>{});
            const auto g_window_ = make_tile_window(g_view_1_,
                                                    make_tuple(number<BlockShape::Block_Nr0>{},
                                                               number<BlockShape::Block_Kr0>{},
                                                               number<BlockShape::Block_W0>{}),
                                                    {0, 0, 0});
            return g_window_;
        }();

        const auto d_window = [&]() {
            // note hidden_idx_nr is along the gemm-k dim of the 2nd gemm
            // (the inner lambda in the original dropped both the assignment and
            //  its return; reconstructed so d_ptr actually gets a value)
            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                     hidden_idx_nr * BlockShape::Block_W1;
            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                number<Pipeline::kAlignmentD>{},
                number<1>{});
            const auto d_view_1_ =
                pad_tensor_view(d_view_,
                                make_tuple(number<BlockShape::Block_Nr1>{},
                                           number<BlockShape::Block_Kr1>{},
                                           number<BlockShape::Block_W1>{}),
                                sequence<PadHiddenSize, PadIntermediateSize, 0>{});
            const auto d_window_ = make_tile_window(d_view_1_,
                                                    make_tuple(number<BlockShape::Block_Nr1>{},
                                                               number<BlockShape::Block_Kr1>{},
                                                               number<BlockShape::Block_W1>{}),
                                                    {0, 0, 0});
            return d_window_;
        }();

        auto o_window = [&]() {
            // (o_ptr is the output; the original cast it to a const pointer)
            ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
            const auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                                                        memory_operation_enum::atomic_add>(
                o_ptr,
                make_tuple(kargs.num_tokens, kargs.hidden_size),
                make_tuple(kargs.stride_token, 1),
                number<Pipeline::kAlignmentO>{},
                number<1>{});
            // the scatter happens here
            const auto o_scatter_view_ = transform_tensor_view(
                o_view_,
                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                           make_pass_through_transform(kargs.hidden_size)),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            const auto o_window_ = make_tile_window(
                o_scatter_view_,
                make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
                {0, 0});
            return o_window_;
        }();

        // do the compute
        Pipeline{}(a_window,
                   g_window,
                   d_window,
                   o_window,
                   topk_weight,
                   smem,
                   kargs.hidden_size,
                   kargs.intermediate_size);
    }
};
} // namespace ck_tile
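A hedged end-to-end setup sketch for this kernel (editorial; every dev_* name and the MoeKernel alias are placeholders for whatever the 15_fused_moe example harness ends up using, and the sizes are illustrative):

    // using MoeKernel = ck_tile::FusedMoeGemmKernel<Partitioner, Pipeline, Epilogue>;
    ck_tile::FusedMoeGemmHostArgs hargs{};
    hargs.a_ptr                 = dev_a;          // [m, k] tokens
    hargs.g_ptr                 = dev_gate_up;    // pre-shuffled [e, nr, kr, w]
    hargs.d_ptr                 = dev_down;       // pre-shuffled [e, nr, kr, w]
    hargs.o_ptr                 = dev_out;        // [m, k], accumulated via atomic_add
    hargs.sorted_token_ids_ptr  = dev_sorted_ids; // built as in the sort sketch above
    hargs.sorted_weight_ptr     = dev_sorted_w;
    hargs.sorted_expert_ids_ptr = dev_sorted_eids;
    hargs.num_sorted_tiles_ptr  = dev_num_tiles;  // single device-side integer
    hargs.hidden_size           = 4096;
    hargs.intermediate_size     = 14336;          // per-TP slice
    hargs.num_tokens            = 128;
    hargs.num_experts           = 8;
    hargs.stride_token          = hargs.hidden_size;
    auto kargs = MoeKernel::MakeKargs(hargs);     // bit_cast while layouts match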
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
tensors:
1. act  (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation (Silu)
3. up   (U): B matrix for first gemm
4. down (D): B matrix for second gemm

(dataflow; the original ASCII diagram is condensed here)
    A[M0, K0=hidden] x Gate[N0, K0] (+ Up[N0, K0]) --ACT--> Y[M1, K1],  K1 = N0
    Y[M1, K1] x Down[N1=hidden, K1] --> Out(O)[M1, N1]
the hidden dim of A and the K dims of Gate/Up/Down are contiguous

* Note: Act could be Gelu/Silu/...
* Note: some models do not have Up
*/
template <typename BlockTile_0_,
          typename WarpPerBlock_0_,
          typename WarpTile_0_,
          typename BlockTile_1_,
          typename WarpPerBlock_1_,
          typename WarpTile_1_>
struct FusedMoeGemmShape
{
    using BlockTile_0    = remove_cvref_t<BlockTile_0_>;
    using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
    using WarpTile_0     = remove_cvref_t<WarpTile_0_>;
    using BlockTile_1    = remove_cvref_t<BlockTile_1_>;
    using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
    using WarpTile_1     = remove_cvref_t<WarpTile_1_>;

    static constexpr index_t NumWarps =
        reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
    static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));

    static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
    static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
    static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
    // ("numner" typos in the original are fixed to "number" throughout)
    static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
    static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
    static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
    static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
    static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
    static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
    static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
    static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
    static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
    static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
    static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
    static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
    static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
    static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;

    static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
    static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
    static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
    // (the original swapped WarpPerBlock_1/WarpTile_1 in the next six lines
    //  relative to the gemm-0 block above; un-swapped so *_1 mirrors *_0)
    static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
    static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
    static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
    static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
    static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
    static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
    static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
    static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
    static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
    static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
    static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
    static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
    static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
    static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;

    static constexpr index_t BlockSize = warpSize * NumWarps;

    // some asserts
    static_assert(Block_M0 == Block_M1);
    static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up

    // pre-shuffle tile size computation (assumed only for the B matrix):
    // we flatten each wave tile into a 1d linear tensor (at model loading time).
    // e.g. originally we have a Block_N*Block_K tile; after pre-shuffle we have
    // Block_Nr*Block_Kr*Block_W, where Block_W = Warp_N*Warp_K,
    // Block_Nr = Block_N/Warp_N and Block_Kr = Block_K/Warp_K
    static constexpr index_t Block_W0  = Warp_N0 * Warp_K0;
    static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
    static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
    static constexpr index_t Block_W1  = Warp_N1 * Warp_K1;
    static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
    static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;

    static_assert(Block_W0 == Block_W1);
    static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile
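A hedged instantiation sketch (tile sizes are illustrative, chosen so the static_asserts hold: Block_M0 == Block_M1, Block_K1 == Block_N0 for the gate-only case, plus Block_W0 == Block_W1 and Block_Nr0 == Block_Kr1 for the pre-shuffled layout):

    using Shape = ck_tile::FusedMoeGemmShape<
        ck_tile::sequence<32, 128, 128>, // BlockTile_0: M0, N0, K0
        ck_tile::sequence<1, 4, 1>,      // WarpPerBlock_0
        ck_tile::sequence<16, 16, 32>,   // WarpTile_0 (Warp_N0 = 16, Warp_K0 = 32)
        ck_tile::sequence<32, 128, 128>, // BlockTile_1: M1, N1, K1 (= N0, gate-only)
        ck_tile::sequence<1, 4, 1>,      // WarpPerBlock_1
        ck_tile::sequence<16, 32, 16>>;  // WarpTile_1 (Warp_N1 = Warp_K0, Warp_K1 = Warp_N0)
    static_assert(Shape::NumWarps == 4 && Shape::Block_W0 == Shape::Block_W1);

The last warp tile is deliberately "transposed": Warp_K1 must equal Warp_N0 so that Block_Nr0 == Block_Kr1, and Warp_N1 * Warp_K1 must equal Warp_N0 * Warp_K0 so the flattened wave buffers of both gemms share one width.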
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
    using BlockShape = ck_tile::remove_cvref_t<BlockShape_>; // FusedMoeGemmShape

    static constexpr const char* name = "eh"; // expert x hidden

    // (the original parameter list carried a stray extra ')')
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*hidden_size*/)
    {
        index_t i_n = blockIdx.x;
        index_t i_m = blockIdx.y;
        return ck_tile::make_tuple(i_m, i_n);
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
    {
        // TODO: this may need tuning
        index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
        index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0);
        return dim3(ns, ms, 1);
    }
};
} // namespace ck_tile
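A hedged sketch of how this partitioner maps blocks to work (names are illustrative; note the kernel's GridSize wrapper above forwards (num_cu, blocks_per_cu), while this signature takes (max_tokens, hidden_size), so one of the two will presumably be reconciled later):

    using P = ck_tile::FusedMoeGemmTilePartitioner_Linear<Shape>; // Shape = FusedMoeGemmShape<...>
    // upper bound of the sorted dim, as derived in the kernel header comments
    ck_tile::index_t max_tokens_post_padded =
        top_k * num_tokens + num_experts * (Shape::Block_M0 - 1);
    dim3 grids = P::GridSize(max_tokens_post_padded, hidden_size);
    // in-kernel: auto [sorted_tile_id, hidden_tile_id] = P{}(num_sorted_tiles, hidden_size);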
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

/*
This pipeline deals with a gemm (actually 2 back-to-back gemms) with one very
small operand (tokens) and one very big one (weights). We design the pipeline
such that all waves lie along the gemm-N dim (gemm-M has only 1 wave):

    <----- gemm-N ------>
    +----+----+----+----+
    | w0 | w1 | w2 | w3 |  gemm-M
    +----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_Flatmm
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Problem::ADataType;
    using GDataType            = typename Problem::GDataType;
    using DDataType            = typename Problem::DDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    // (the original said "typename Pipeline::Problem::..." for the next two,
    //  but there is no Pipeline alias inside the pipeline itself)
    using YDataType = typename Problem::YDataType;
    using Traits    = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "fused_moe_flatmm";

    // (copy-paste leftover from the fmha pipeline -- kHasDropout is not part of
    //  this Problem, so the alias is kept only as a comment here)
    // using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

    // TODO: there are multiple buffers
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
    {
        return Policy::template GetSmemSize_A<Problem>();
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
                                   const GWindow& g_window_,
                                   const DWindow& d_window_,
                                   OWindow& o_window_,
                                   TopkWeightDataType topk_weight,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t hidden_size,
                                   index_t intermediate_size)
    {
        _Pragma("clang diagnostic push")
        _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")

        constexpr auto NEG1  = number<-1>{}; // (the original ended this line with ':', a typo)
        constexpr auto I0    = number<0>{};
        constexpr auto I1    = number<1>{};
        constexpr auto TRUE  = bool_constant<true>{};
        constexpr auto FALSE = bool_constant<false>{};

        CK_TILE_LDS_ADDR void* smem_0 = smem;
        CK_TILE_LDS_ADDR void* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR void*>(
            reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) + GetSmemSize_A());

        auto g_view = g_window_.get_bottom_tensor_view();
        auto u_view = [&]() {
            if constexpr(IsGateOnly)
            {
                return g_view;
            }
            else
            {
                // (the original read kargs.* here; at this point the sizes are
                //  the function parameters)
                index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
                index_t kr_0 = hidden_size / BlockShape::Block_Kr0;

                const GDataType* g_ptr =
                    g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
                const GDataType* u_ptr =
                    g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};

                const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
                    u_ptr,
                    make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                    number<kAlignmentG>{},
                    number<1>{});
                const auto u_view_1_ =
                    pad_tensor_view(u_view_,
                                    make_tuple(number<BlockShape::Block_Nr0>{},
                                               number<BlockShape::Block_Kr0>{},
                                               number<BlockShape::Block_W0>{}),
                                    sequence<PadIntermediateSize, PadHiddenSize, 0>{});
                return u_view_1_;
            }
        }();

        auto a_win = make_tile_window_linear(
            a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
        auto g_win = make_tile_window_linear(
            g_window_, Policy::template MakeGlobalTileDistribution_G<Problem>());
        auto d_win = make_tile_window_linear(
            d_window_, Policy::template MakeGlobalTileDistribution_D<Problem>());
        auto o_win = make_tile_window_linear(
            o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());

        using g_thread_type = decltype(load_tile(g_win));
        using d_thread_type = decltype(load_tile(d_win));
        // (the original also declared a u_thread_type from a non-existent
        //  u_win; the u tiles reuse g_win's distribution, i.e. g_thread_type)

        using WarpGemm0 = remove_cvref_t<decltype(Policy::template GetWarpGemm0<Problem>())>;
        using WarpGemm1 = remove_cvref_t<decltype(Policy::template GetWarpGemm1<Problem>())>;

        // issues * warps * lanes
        auto a_sst_win0 =
            make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
                                        smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
                                    Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
                                    {0, 0, 0});
        auto a_sst_win1 =
            make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
                                        smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
                                    Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
                                    {0, 0, 0});

        // m*k
        auto a_sld_win0 = [&]() {
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
                      sequence<BlockShape::Repeat_K0>>,
                tuple<sequence<1>>,
                tuple<sequence<1>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
                a_outer_dstr_enc, typename WarpGemm0::AWarpDstrEncoding{});
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
                Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
                {0, 0},
                a_block_dstr_encode);
        }();

        // m*k
        auto a_sld_win1 = [&]() {
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
                      sequence<BlockShape::Repeat_K0>>,
                tuple<sequence<1>>,
                tuple<sequence<1>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
                a_outer_dstr_enc, typename WarpGemm0::AWarpDstrEncoding{});
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
                Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
                {0, 0},
                a_block_dstr_encode);
        }();

        // (these two were declared as lambdas in the original but then used
        //  directly as windows below; the invoking "()" was missing)
        auto bridge_sst_win = make_tile_window_linear(
            make_tensor_view<address_space_enum::lds>(
                smem, Policy::template MakeBridgeLdsStoreDesc<Problem>()),
            Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
            {0, 0});
        auto bridge_sld_win = make_tile_window_linear(
            make_tensor_view<address_space_enum::lds>(
                smem, Policy::template MakeBridgeLdsLoadDesc<Problem>()),
            Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
            {0, 0},
            Policy::template MakeYTileDistribution<Problem>()); // ("tepmlate" typo fixed)

        // also OK with a C array: 2 register buffers
        statically_indexed_array<g_thread_type, 2> gs;

        constexpr auto issues_a = number<a_win.get_num_of_access()>{};
        constexpr auto issues_g = number<g_win.get_num_of_access()>{};
        constexpr auto issues_d = number<d_win.get_num_of_access()>{};
        constexpr auto issues_o = number<o_win.get_num_of_access()>{};
        constexpr auto issues_gemm0 =
            number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0>{};
        constexpr auto issues_gemm1 =
            number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1>{};
        constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};

        const index_t num_blocks_k0 =
            (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
        const index_t num_blocks_n1 =
            (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;

        using a_thread_type = decltype(load_tile(a_sld_win0));
        statically_indexed_array<a_thread_type, 2> as;

        auto gld_a = [&]<typename PreNop = bool_constant<false>>(
                         auto& a_store_, auto i_access, PreNop = {}) {
            async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
        };
        auto move_a = [&]() {
            move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
        };
        auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
            load_tile_raw(a_, win_, i_access);
        };

        auto gld_g = [&]<typename PreNop = bool_constant<false>>(
                         auto& g_, auto i_access, PreNop = {}) {
            if constexpr(IsGateOnly)
            {
                // TODO: hack! swap the window's bottom view halfway through
                if constexpr(i_access.value == 0)
                {
                    g_win.bottom_tensor_view_ = g_view;
                }
                else if constexpr(i_access.value == issues_g / 2)
                {
                    g_win.bottom_tensor_view_ = u_view;
                }
            }
            load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
        };
        auto move_g = [&]() {
            move_tile_window(g_win,
                             {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
        }; // (missing ';' in the original)

        statically_indexed_array<d_thread_type, 2> ds;

        auto gld_d = [&]<typename PreNop = bool_constant<false>>(
                         auto& d_, auto i_access, PreNop = {}) {
            load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
        };
        auto move_d = [&]() {
            // d moves along gemm-n
            move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
        };

        auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
                                auto& o_, auto i_access, PreNop = {}) {
            update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
        }; // (missing ';' in the original)

        // (qualified with Policy::, where the other tile makers live; the
        //  original generate_tuple lambda also dropped its return, and built
        //  the gemm-1 accumulators from the gemm-0 maker -- Gemm1 was
        //  presumably intended)
        auto acc_0  = Policy::template MakeCBlockTile_Gemm0<Problem>();
        auto acc_1s = generate_tuple(
            [&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});

        // clang-format off
        auto gemm_0 = [&]<typename PostNop = bool_constant<false>>(
                          auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
            using WarpGemm = WarpGemm0;
            constexpr auto repeat_m = BlockShape::Repeat_M0;
            constexpr auto repeat_n = BlockShape::Repeat_N0;
            constexpr auto repeat_k = BlockShape::Repeat_K0;
            (void)repeat_n;
            // loop order n->m->k
            constexpr auto i_k = i_access % repeat_k;
            constexpr auto i_m = (i_access / repeat_k) % repeat_m;
            constexpr auto i_n = (i_access / repeat_k) / repeat_m;

            using AWarpDstr   = typename WarpGemm::AWarpDstr;
            using BWarpDstr   = typename WarpGemm::BWarpDstr;
            using CWarpDstr   = typename WarpGemm::CWarpDstr;
            using AWarpTensor = typename WarpGemm::AWarpTensor;
            using BWarpTensor = typename WarpGemm::BWarpTensor;
            using CWarpTensor = typename WarpGemm::CWarpTensor;

            // (the *WarpDstr aliases and the *_warp_y_lengths below were used
            //  but never spelled out in the original; filled in following the
            //  usual ck_tile warp-gemm slicing pattern)
            constexpr auto a_warp_y_lengths =
                to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto b_warp_y_lengths =
                to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

            AWarpTensor w_a;
            w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

            BWarpTensor w_b;
            w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
                merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            CWarpTensor w_c;
            w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

            WarpGemm{}(w_c, w_a, w_b, PostNop{});

            t_c.set_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                w_c.get_thread_buffer());
        };
        // clang-format on

        // clang-format off
        auto gemm_1 = [&]<typename PostNop = bool_constant<false>>(
                          auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
            using WarpGemm = WarpGemm1;
            constexpr auto repeat_m = BlockShape::Repeat_M1;
            constexpr auto repeat_n = BlockShape::Repeat_N1;
            constexpr auto repeat_k = BlockShape::Repeat_K1;
            (void)repeat_n;
            // loop order n->m->k
            constexpr auto i_k = i_access % repeat_k;
            constexpr auto i_m = (i_access / repeat_k) % repeat_m;
            constexpr auto i_n = (i_access / repeat_k) / repeat_m;

            using AWarpDstr   = typename WarpGemm::AWarpDstr;
            using BWarpDstr   = typename WarpGemm::BWarpDstr;
            using CWarpDstr   = typename WarpGemm::CWarpDstr;
            using AWarpTensor = typename WarpGemm::AWarpTensor;
            using BWarpTensor = typename WarpGemm::BWarpTensor;
            using CWarpTensor = typename WarpGemm::CWarpTensor;

            constexpr auto a_warp_y_lengths =
                to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto b_warp_y_lengths =
                to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

            AWarpTensor w_a;
            w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

            BWarpTensor w_b;
            w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
                merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            CWarpTensor w_c;
            w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

            WarpGemm{}(w_c, w_a, w_b, PostNop{});

            t_c.set_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                w_c.get_thread_buffer());
        };
        // clang-format on
        _Pragma("clang diagnostic pop")

        // This gemm pipeline is designed on the assumption that buffer-load /
        // ds_read issues can be hidden under mfma; in other words, the number
        // of mfma issues is >= the number of memory issues. This holds if we
        // pre-shuffle the B matrix and the A matrix is relatively small. We
        // prefer to pair multiple mfma with one buffer-load of the B matrix to
        // get max buffer_load throughput; with pre-shuffle we always pack to
        // dwordx4 loads, which already spreads over multiple mfma -- but that
        // is consumed inside the warpgemm impl. So how many extra mfma can
        // reuse the B matrix is really only affected by the M repeat.
        auto pipeline_gemm0 = [&]() {
            constexpr index_t total_loops    = issues_gemm0;
            constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0
            constexpr index_t mfma_per_gld_a = total_loops / issues_a;
            constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
            // compute buffer 0
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_0(acc_0, as[I0], gs[I0], i_issue);
                if constexpr(i_issue % mfma_per_gld_g == 0)
                    gld_g(gs(I1), number<i_issue / mfma_per_gld_g>{});
                if constexpr(i_issue % mfma_per_gld_a == 0)
                    gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
                if constexpr(i_issue % mfma_per_sld_a == 0)
                    sld_a(as(I1), a_sld_win1, number<i_issue / mfma_per_sld_a>{});
            });
            // compute buffer 1
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_0(acc_0, as[I1], gs[I1], i_issue);
                if constexpr(i_issue % mfma_per_gld_g == 0)
                    gld_g(gs(I0), number<i_issue / mfma_per_gld_g>{});
                if constexpr(i_issue % mfma_per_gld_a == 0)
                    gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
                if constexpr(i_issue % mfma_per_sld_a == 0)
                    sld_a(as(I0), a_sld_win0, number<i_issue / mfma_per_sld_a>{});
            });
        };

        auto pipeline_gemm0_tail = [&]() {
            constexpr index_t total_loops    = issues_gemm0;
            constexpr index_t mfma_per_gld_g = total_loops / issues_g;
            constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
            // compute buffer 0
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_0(acc_0, as[I0], gs[I0], i_issue);
                if constexpr(i_issue % mfma_per_gld_g == 0)
                    gld_g(gs(I1), number<i_issue / mfma_per_gld_g>{});
                // if constexpr(i_issue % mfma_per_gld_a == 0)
                //     gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
                if constexpr(i_issue % mfma_per_sld_a == 0)
                    sld_a(as(I1), a_sld_win1, number<i_issue / mfma_per_sld_a>{});
            });
            // compute buffer 1
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_0(acc_0, as[I1], gs[I1], i_issue, TRUE); // last gemm carries the nop
            });
        };

        auto y = Policy::template MakeYBlockTile<Problem>();

        auto pipeline_bridge = [&]() {
            // cast to the Y data type
            auto y_pre = cast_tile<YDataType>(acc_0);
            store_tile(bridge_sst_win, y_pre);
            clear_tile(acc_1s(I0));
            wave_barrier();
            load_tile(y, bridge_sld_win);
            clear_tile(acc_1s(I1));
        };

        // note, gemm-1 runs from idx 1 to N-2 (of 0, 1, 2 ... N-1)
        auto pipeline_gemm1 = [&]() {
            constexpr index_t total_loops    = issues_gemm1;
            constexpr index_t mfma_per_gld_d = total_loops / issues_d; // BlockShape::Repeat_M0
            constexpr index_t mfma_per_atm_o = total_loops / issues_o;
            // compute buffer 1
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_1(acc_1s(I1), y, ds[I1], i_issue); // (the original called gemm_0 here)
                if constexpr(i_issue % mfma_per_gld_d == 0)
                    gld_d(ds(I0), number<i_issue / mfma_per_gld_d>{});
                if constexpr(i_issue % mfma_per_atm_o == 0)
                {
                    auto out = cast_tile<ODataType>(acc_1s[I0]);
                    atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
                }
            });
            // compute buffer 0
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_1(acc_1s(I0), y, ds[I0], i_issue);
                if constexpr(i_issue % mfma_per_gld_d == 0)
                    gld_d(ds(I1), number<i_issue / mfma_per_gld_d>{});
                if constexpr(i_issue % mfma_per_atm_o == 0)
                {
                    auto out = cast_tile<ODataType>(acc_1s[I1]);
                    atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
                }
            });
        };

        auto pipeline_gemm1_head = [&]() {
            constexpr index_t total_loops    = issues_gemm1;
            constexpr index_t mfma_per_gld_d = total_loops / issues_d;
            // compute buffer 0
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_1(acc_1s(I0), y, ds[I0], i_issue);
                if constexpr(i_issue % mfma_per_gld_d == 0)
                    gld_d(ds(I1), number<i_issue / mfma_per_gld_d>{});
            });
        };

        auto pipeline_gemm1_tail = [&]() {
            constexpr index_t total_loops    = issues_gemm1;
            constexpr index_t mfma_per_gld_d = total_loops / issues_d;
            constexpr index_t mfma_per_atm_o = total_loops / issues_o;
            // compute buffer 1
            static_for<0, total_loops, 1>{}([&](auto i_issue) {
                gemm_1(acc_1s(I1), y, ds[I1], i_issue);
                if constexpr(i_issue % mfma_per_gld_d == 0)
                    gld_d(ds(I0), number<i_issue / mfma_per_gld_d>{});
                if constexpr(i_issue % mfma_per_atm_o == 0)
                {
                    auto out = cast_tile<ODataType>(acc_1s[I0]);
                    atomic_add_o(out, number<i_issue / mfma_per_atm_o>{});
                }
            });
            {
                auto out = cast_tile<ODataType>(acc_1s[I1]);
                atomic_add_o(out, NEG1);
            }
        };

        // start of pipeline
        // clang-format off
        gld_a(a_sst_win0, NEG1, TRUE);
        gld_g(gs(I0), NEG1, TRUE);
        sld_a(as(I0), a_sld_win0, NEG1);
        gld_a(a_sst_win1, NEG1);
        clear_tile(acc_0);

        // we manually unroll the double buffer inside the hot loop
        const index_t iters_0 = (num_blocks_k0 - 2) / 2;
        index_t i_0           = 0;
        while(i_0 < iters_0)
        {
            pipeline_gemm0();
            i_0++; // (the increment was missing in the original)
        }
        pipeline_gemm0_tail();
        pipeline_bridge();

        const index_t iters_1 = (num_blocks_n1 - 2) / 2;
        index_t i_1           = 0;
        pipeline_gemm1_head();
        while(i_1 < iters_1) // (the original re-tested i_0 < iters_0 here)
        {
            pipeline_gemm1();
            i_1++;
        }
        pipeline_gemm1_tail();
        // clang-format on
    }
};
} // namespace ck_tile
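To see how the interleaving constants in the hot loops work out, a worked example with illustrative numbers (not a configuration shipped in this commit): with Repeat_M0 = 2, Repeat_N0 = 4 and Repeat_K0 = 2 there are 16 warp-gemm (mfma) issues per unrolled step; if the G window reports 4 buffer-load accesses, one G load is slotted in after every 4th mfma:

    constexpr int issues_gemm0   = 2 * 4 * 2;               // Repeat_M0 * Repeat_N0 * Repeat_K0
    constexpr int issues_g       = 4;                       // g_win.get_num_of_access() (assumed)
    constexpr int mfma_per_gld_g = issues_gemm0 / issues_g; // = 4
    static_assert(issues_gemm0 % issues_g == 0, "loads must divide the mfma issues evenly");

This also makes the design constraint explicit: each issue count (a, g, d, o, sld_a) has to divide the mfma issue count, otherwise the "i_issue % mfma_per_*" slots above would skip loads.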
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

struct FusedMoeGemmPipelineFlatmmPolicy
{
    CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
    {
        // TODO: always 1 dword for now
        return 1;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
    {
        // using async copy
        constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
        constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
    {
        constexpr index_t copy_bytes = 16; // dwordx4
        constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
    {
        constexpr index_t copy_bytes = 16; // dwordx4
        constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
    {
        if constexpr(Problem::Traits::OAtomic == 1)
        {
            // packed fp16/bf16 atomic
            static_assert(sizeof(typename Problem::ODataType) == 2);
            return 2;
        }
        else if constexpr(Problem::Traits::OAtomic == 2)
        {
            // fp32 atomic
            return 1;
        }
        else
        {
            return 16 / sizeof(typename Problem::ODataType);
        }
    }

    // (the original declared "typename DataType_" but then read
    //  Problem::DataType_ in the body; fixed to use the parameter directly)
    template <typename DataType_>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
    {
        // TODO: this is for the 3d layout
        return 16 / sizeof(remove_cvref_t<DataType_>);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
    {
        return GetSmemKPack<typename Problem::ADataType>();
    }
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape()
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
return sequence<Kw, Nw, Kv>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockTileNrKr()
{
using WarpGemm = GetWarpGemm0<Problem>{}; // assume warpgemm0/1 are the same
constexpr index_t Kv = GetAlignment_G<{Problem}>();
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
return sequence<Problem::BlockShape::Block_K0 / Nw,
Problem::BlockShape::Block_K0 / (Kw * Kv)>{};
}
#endif
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
    {
        constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
        constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
        static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
        return a_sld_desc.get_element_space_size();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
    {
        constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
        constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
        static_assert(bridge_sld_desc.get_element_space_size() ==
                      bridge_sst_desc.get_element_space_size());
        return bridge_sld_desc.get_element_space_size();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t a_lds      = GetSmemSize_A<Problem>();
        constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
        return max(a_lds, bridge_lds);
    }

    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
    {
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;
        if constexpr(get_warp_size() < K_rem)
        {
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lanes within a wave run along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav>,
                                                 sequence<K_wav, K_lan, K_vec>>,
                                           tuple<sequence<1, 2>, sequence<2>>,
                                           tuple<sequence<1, 0>, sequence<1>>,
                                           sequence<1, 2>,
                                           sequence<0, 2>>{});
        }
        else
        {
            constexpr index_t K_lan = K_rem;
            constexpr index_t M_lan = get_warp_size() / K_lan;
            constexpr index_t M_wav = NumWarps;
            static_assert(MPerBlock % (M_lan * M_wav) == 0,
                          "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav, M_lan>,
                                                 sequence<K_lan, K_vec>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
    }

    // optimized version for async; note this is NOT the same as the simple MxK
    // distribution (pay attention!)
    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
    {
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;
        if constexpr(get_warp_size() <= K_rem)
        {
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lanes within a wave run along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
            // NOTE: no swap, but hard to avoid LDS bank conflicts
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav>,
                                                 sequence<K_wav, K_lan, K_vec>>,
                                           tuple<sequence<1, 2>, sequence<2>>,
                                           tuple<sequence<1, 0>, sequence<1>>,
                                           sequence<1, 2>,
                                           sequence<0, 2>>{});
        }
        else
        {
            constexpr index_t K_lan = K_rem;
            constexpr index_t M_lan = get_warp_size() / K_lan;
            constexpr index_t M_wav = NumWarps;
            static_assert(MPerBlock % (M_lan * M_wav) == 0,
                          "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
            // NOTE: swapped for bank-conflict-free LDS loads. M_wav (num waves)
            // is the fastest dim here, different from the simple 2d distribution.
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_lan, M_wav>,
                                                 sequence<K_lan, K_vec>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<2>, sequence<1, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
    {
        if constexpr(Problem::Traits::PermuteEnum ==
                     FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            // (the original wrote "using WarpGemm = GetWarpGemm0<Problem>{};"
            //  and "GetAlignment_G<{Problem}>()", both ill-formed; its
            //  static_assert also referenced K1/K2, which do not exist here)
            using WarpGemm = decltype(GetWarpGemm0<Problem>()); // assume warpgemm0/1 are the same
            constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
            constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
            constexpr index_t Kv = GetAlignment_G<Problem>();
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kw * Kv) == 0);
            constexpr index_t Nr = NPerBlock / Nw;
            constexpr index_t Kr = KPerBlock / (Kv * Kw);
            return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
        }
    }
#if 0
// Caution: this will require global memory pre-shuffled to follow the mfma layout
template <index_t NPerBlock,
index_t KPerBlock,
index_t WavesPerBlock_N,
index_t WavesPerBlock_K,
typename WarpGemm,
index_t Alignment,
FusedMoeGemmWeightPermuteEnum PermuteEnum =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>, // 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
// Nr_p, Kr_p Kw Nw
tuple<sequence<1, 2>, sequence<3, 3>>,
tuple<sequence<1, 1>, sequence<0, 1>>,
// Nr_y Kr_y Kv
sequence<1, 2, 3>,
sequence<0, 0, 2>>{});
// clang-format on
}
}
#endif
template
<
index_t
WarpPerBlock_N_
,
index_t
WarpPerBlock_K_
,
index_t
Repeat_N_
,
index_t
Repeat_K_
,
index_t
WarpSize_
,
index_t
Alignment_
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_Nr_Kr_W
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Repeat_N_
,
WarpPerBlock_N_
>
,
sequence
<
Repeat_K_
,
WarpPerBlock_K_
>
,
sequence
<
WarpSize_
,
Alignment_
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_A
()
{
constexpr
index_t
Block_M_
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_K_
=
Problem
::
BlockShape
::
Block_K0
;
constexpr
index_t
NumWarps_
=
Problem
::
BlockShape
::
NumWarps
;
constexpr
index_t
Alignment_
=
GetAlignment_A
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
Block_M_
,
Block_K_
,
NumWarps_
,
Alignment_
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N0
,
S_
::
WarpPerBlock_K0
,
S_
::
Repeat_N0
,
/// hidden_radio_0,
S_
::
Repeat_K0
,
get_warp_size
(),
GetAlignment_G
<
Problem
>
()
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N1
,
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_N1
,
S_
::
Repeat_K1
,
get_warp_size
(),
GetAlignment_D
<
Problem
>
()
>
();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
    // A async->LDS
    constexpr index_t Block_M   = Problem::BlockShape::Block_M0;
    constexpr index_t Block_K   = Problem::BlockShape::Block_K0;
    constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
    constexpr index_t warpSize  = ck_tile::get_warp_size();
    constexpr index_t NumWarps  = Problem::BlockShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(Block_K % KVector == 0);
    constexpr index_t LanesPerK = Block_K / KVector; // how many lanes load K
    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        if constexpr(wavesPerK > NumWarps)
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = Block_M / wavesPerM;
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{}, // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<KVector>{}),  // k2
                make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{},  // m0
                           number<wavesPerK * (warpSize * KVector + kPad)>{}, // m1
                           number<warpSize * KVector + kPad>{},               // k0
                           number<KVector>{},                                 // k1
                           number<1>{}),                                      // k2
                number<KVector>{}, // lds store vector (actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_pass_through_transform(number<NumIssues>{}),
                    make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                    make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{}, // m0
                       number<Block_K>{},                                // m1
                       number<warpSize * KVector + kPad>{},              // m2
                       number<KVector>{},                                // k0
                       number<1>{}),                                     // k1
            number<KVector>{}, // lds store vector (actually no explicit store)
            number<1>{});

        constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(
                make_pass_through_transform(number<NumIssues>{}),
                make_pass_through_transform(number<NumWarps>{}),
                make_merge_transform(
                    make_tuple(number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        return lds_block_desc_issues_warps_lanes;
    }
}
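A quick standalone sanity check of the arithmetic in the second branch above, with assumed example numbers (Block_M0 = 32, Block_K0 = 128, KVector = KPack = 8; these values are illustrative only, not fixed by the policy). Each wave's row gets kPad extra elements appended, so consecutive waves start warpSize * KVector + kPad apart:

constexpr int warpSize = 64, NumWarps = 4;
constexpr int KVector = 8, KPack = 8, kPad = KPack;
constexpr int Block_M = 32, Block_K = 128;

constexpr int LanesPerK  = Block_K / KVector;                 // 16 lanes cover one K row
constexpr int LaneGroups = warpSize / LanesPerK;              // 4 M rows per wave
constexpr int NumIssues  = Block_M / (LaneGroups * NumWarps); // 2 async issues
constexpr int ldsElems   = NumIssues * NumWarps * (warpSize * KVector + kPad);

static_assert(LanesPerK == 16 && LaneGroups == 4 && NumIssues == 2);
static_assert(ldsElems == 4160); // padded LDS footprint in elements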
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
    // A async->LDS.
    // Note that this descriptor only describes the layout inside LDS;
    // in the real gemm pipeline the ds_read may not follow this pattern
    // (it may follow the pattern in tile_distribution).
    // The code below is almost the same as the SmemStore descriptor, with two differences:
    // 1) the GuaranteedLastDimensionVectorLength of the naive tensor desc is KPack
    // 2) the returned descriptor is an MxK 2d layout
    constexpr index_t Block_M   = Problem::BlockShape::Block_M0;
    constexpr index_t Block_K   = Problem::BlockShape::Block_K0;
    constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
    constexpr index_t warpSize  = ck_tile::get_warp_size();
    constexpr index_t NumWarps  = Problem::BlockShape::NumWarps;

    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
    constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
    constexpr index_t kPad    = KPack;                     // pad between warps

    static_assert(Block_K % KVector == 0);
    constexpr index_t LanesPerK = Block_K / KVector; // how many lanes load K
    if constexpr(LanesPerK >= warpSize)
    {
        // need multiple waves to load K
        static_assert(LanesPerK % warpSize == 0);
        constexpr index_t wavesPerK = LanesPerK / warpSize;
        if constexpr(wavesPerK > NumWarps)
        {
            // TODO: need multiple issues along K to load all data
        }
        else
        {
            constexpr index_t wavesPerM = NumWarps / wavesPerK;
            constexpr index_t NumIssues = Block_M / wavesPerM;
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{}, // m0
                           number<wavesPerM>{}, // m1
                           number<wavesPerK>{}, // k0
                           number<warpSize>{},  // k1
                           number<KVector>{}),  // k2
                make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{},  // m0
                           number<wavesPerK * (warpSize * KVector + kPad)>{}, // m1
                           number<warpSize * KVector + kPad>{},               // k0
                           number<KVector>{},                                 // k1
                           number<1>{}),                                      // k2
                number<KPack>{}, // lds load vector
                number<1>{});

            constexpr auto lds_desc_m_k = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(
                    make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
                    make_merge_transform(
                        make_tuple(number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return lds_desc_m_k;
        }
    }
    else
    {
        // lanes within a wave load different M but same K
        static_assert(warpSize % LanesPerK == 0);
        constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
        constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<NumIssues>{},  // m0
                       number<LaneGroups>{}, // m1
                       number<NumWarps>{},   // m2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps * (warpSize * KVector + kPad)>{}, // m0
                       number<Block_K>{},                                // m1
                       number<warpSize * KVector + kPad>{},              // m2
                       number<KVector>{},                                // k0
                       number<1>{}),                                     // k1
            number<KPack>{}, // lds load vector
            number<1>{});

        constexpr auto lds_desc_m_k = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(
                make_merge_transform(
                    make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return lds_desc_m_k;
    }
}
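The store and load descriptors view the same padded bytes; only the guaranteed vector length differs (KVector on the async store path, KPack for ds_read), and the load side merges the issue/wave/lane axes back into a plain MxK view. A tiny standalone check, under the same assumed example numbers as the previous sketch, that the merged extents reproduce the block tile:

constexpr int warpSize = 64, NumWarps = 4, KVector = 8;
constexpr int Block_M = 32, Block_K = 128;
constexpr int LanesPerK = Block_K / KVector, LaneGroups = warpSize / LanesPerK;
constexpr int NumIssues = Block_M / (LaneGroups * NumWarps);
static_assert(NumIssues * LaneGroups * NumWarps == Block_M); // merged M extent
static_assert(LanesPerK * KVector == Block_K);               // merged K extent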
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
{
    constexpr index_t Block_M = Problem::BlockShape::Block_M0;
    constexpr index_t Block_N = Problem::BlockShape::Block_N0;
    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // reuse the A-side KPack
    constexpr index_t kPad    = KPack;                     // pad between rows

    constexpr auto desc =
        make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
                                     make_tuple(number<Block_N + kPad>{}, number<1>{}),
                                     number<KPack>{},
                                     number<1>{});
    return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
    constexpr index_t Block_M = Problem::BlockShape::Block_M0;
    constexpr index_t Block_N = Problem::BlockShape::Block_N0;
    constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // reuse the A-side KPack
    constexpr index_t kPad    = KPack;                     // pad between rows

    constexpr auto desc =
        make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
                                     make_tuple(number<Block_N + kPad>{}, number<1>{}),
                                     number<KPack>{},
                                     number<1>{});
    return desc;
}
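The bridge buffer pads each row by kPad elements so that walking down a column does not revisit the same LDS bank on every row. A hypothetical standalone illustration, counting in 4-byte element units and assuming 32 LDS banks, Block_N = 64, and kPad = 8 (none of these numbers are fixed by the policy):

constexpr int Banks = 32, Block_N = 64, kPad = 8;
constexpr int bank_of_row(int m) { return (m * (Block_N + kPad)) % Banks; } // padded stride
static_assert(bank_of_row(0) != bank_of_row(1));               // padded: banks differ
static_assert((0 * Block_N) % Banks == (1 * Block_N) % Banks); // unpadded: same bank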
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
    using S_ = typename Problem::BlockShape;
    // A is in vgpr, B is in agpr; since we transposed, these also need to be swapped.
    // TODO: this is ugly
    constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
    // TODO: ugly
    if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
                 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>,
            2>>{};
    }
    else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
                      std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
            2>>{};
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
    using S_ = typename Problem::BlockShape;
    constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vva;
    // TODO: ugly
    if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImplF16F16F32M32N32K<wg_ctrl>,
            2>>{};
    }
    else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
            2>>{};
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0() const
{
    using S_        = remove_cvref_t<typename Problem::BlockShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
    using CDataType = typename WarpGemm::CDataType;

    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
              sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
    return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const
{
    using S_        = remove_cvref_t<typename Problem::BlockShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
    using CDataType = typename WarpGemm::CDataType;

    constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
        sequence<>,
        tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
              sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
        tuple<sequence<1, 2>>,
        tuple<sequence<1, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
    return c_block_tensor;
}
// this is used as the A matrix for the 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYTileDistribution() const
{
    using S_        = remove_cvref_t<typename Problem::BlockShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
    using YDataType = typename Problem::YDataType;

    // TODO: all waves are along different N, but same M
    constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
        sequence<S_::WarpPerBlock_M1>,
        tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
        tuple<sequence<0>>,
        tuple<sequence<0>>,
        sequence<1, 2>,
        sequence<0, 0>>{};

    constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
    constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
    return y_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeYBlockTile() const
{
    constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
    auto y_block_tensor =
        make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
    return y_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
    if constexpr(Problem::Traits::PermuteEnum ==
                 FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
    {
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
        // assume warpgemm0/1 are the same
        constexpr index_t NPerBlock = Problem::BlockShape::Block_N0;
        constexpr index_t KPerBlock = Problem::BlockShape::Block_K0;
        constexpr index_t Kv        = GetAlignment_G<Problem>();
        constexpr index_t Nw        = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw        = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        static_assert(KPerBlock % (Kv * Kw) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);

        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
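Shape check for the swizzled tile returned above, under assumed example numbers (Block_N0 = Block_K0 = 128, a 32x32 mfma lane layout so Nw = 32 and Kw = 2, and Kv = 8; all of these are illustrative assumptions): the 3-D sequence is just a re-factoring of the NxK tile, so the element count must be unchanged.

constexpr int Block_N0 = 128, Block_K0 = 128;
constexpr int Nw = 32, Kw = 2, Kv = 8;
constexpr int Nr = Block_N0 / Nw;        // 4
constexpr int Kr = Block_K0 / (Kv * Kw); // 8
static_assert(Block_K0 % (Kv * Kw) == 0);
static_assert(Nr * Kr * (Kw * Nw * Kv) == Block_N0 * Block_K0); // same element count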
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
    if constexpr(Problem::Traits::PermuteEnum ==
                 FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
    {
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        // assume warpgemm0/1 are the same
        constexpr index_t NPerBlock = Problem::BlockShape::Block_N1;
        constexpr index_t KPerBlock = Problem::BlockShape::Block_K1;
        constexpr index_t Kv        = GetAlignment_G<Problem>();
        constexpr index_t Nw        = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw        = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        static_assert(KPerBlock % (Kv * Kw) == 0);
        constexpr index_t Nr = NPerBlock / Nw;
        constexpr index_t Kr = KPerBlock / (Kv * Kw);

        return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
    }
}
};
} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
0 → 100644
View file @
49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {

// TODO: allow the 2 gemms to have different types
template <typename ADataType_,
          typename GDataType_,
          typename DDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename AScaleDataType_,
          typename GScaleDataType_,
          typename DScaleDataType_,
          typename YSmoothScaleDataType_,
          typename TopkWeightDataType_,
          typename IndexDataType_,  // data type for all indexing
          typename GateActivation_, // = ck_tile::element_wise::Silu,
          typename BlockShape_,     // should be FusedMoeGemmShape
          typename Traits_>
struct FusedMoeGemmPipelineProblem
{
    using ADataType            = remove_cvref_t<ADataType_>;
    using GDataType            = remove_cvref_t<GDataType_>;
    using DDataType            = remove_cvref_t<DDataType_>;
    using AccDataType          = remove_cvref_t<AccDataType_>;
    using ODataType            = remove_cvref_t<ODataType_>;
    using AScaleDataType       = remove_cvref_t<AScaleDataType_>;
    using GScaleDataType       = remove_cvref_t<GScaleDataType_>;
    using DScaleDataType       = remove_cvref_t<DScaleDataType_>;
    using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
    using TopkWeightDataType   = remove_cvref_t<TopkWeightDataType_>;
    using IndexDataType        = remove_cvref_t<IndexDataType_>;

    // the input of the next gemm should have the same type as A
    using YDataType = ADataType;

    using GateActivation = remove_cvref_t<GateActivation_>;
    using BlockShape     = remove_cvref_t<BlockShape_>;
    using Traits         = remove_cvref_t<Traits_>;
};

} // namespace ck_tile
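Since the problem struct only forwards types, a minimal instantiation sketch compiles with placeholder shape/trait types. MyShape and MyTraits below are hypothetical stand-ins (real code would pass a FusedMoeGemmShape and a traits struct carrying PermuteEnum and friends), and the scalar type choices are assumptions, not requirements:

#include <type_traits>
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"

struct MyShape  {}; // placeholder; should be a FusedMoeGemmShape
struct MyTraits {}; // placeholder traits (PermuteEnum, IsGateOnly, ...)

using MyProblem = ck_tile::FusedMoeGemmPipelineProblem<
    ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, // A / G / D data
    float, ck_tile::bf16_t,                            // acc / output
    float, float, float,                               // A / G / D scales
    float, float,                                      // Y smooth-scale, topk weight
    ck_tile::index_t,                                  // index type
    ck_tile::element_wise::Silu,                       // gate activation
    MyShape, MyTraits>;

// Y (input of the 2nd gemm) inherits A's type:
static_assert(std::is_same_v<MyProblem::YDataType, ck_tile::bf16_t>);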