Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b885995c
Commit
b885995c
authored
Dec 12, 2024
by
letaoqin
Browse files
first right version
parent
40df5c8b
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
240 additions
and
113 deletions
+240
-113
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+42
-7
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+1
-1
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+1
-1
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+1
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+21
-14
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+119
-57
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+54
-33
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
b885995c
...
...
@@ -87,32 +87,42 @@ template <typename IndexType>
void
output_matrix_2d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
m
,
int
n
)
{
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"["
;
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
std
::
cout
<<
"
Line "
<<
i
<<
"
\t
"
;
std
::
cout
<<
"
[
"
;
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
i
,
j
))
<<
"
\t
"
;
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
i
,
j
));
if
(
j
!=
n
-
1
)
std
::
cout
<<
", "
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"],
\n
"
;
}
std
::
cout
<<
"]
\n
"
;
}
template
<
typename
IndexType
>
void
output_matrix_3d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
M
,
int
N
,
int
J
)
{
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"["
;
for
(
int
m
=
0
;
m
<
M
;
m
++
)
{
std
::
cout
<<
"["
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
std
::
cout
<<
"
experts: "
<<
m
<<
" Line: "
<<
n
<<
"
\t
"
;
std
::
cout
<<
"
[
"
;
for
(
int
j
=
0
;
j
<
J
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
m
,
n
,
j
))
<<
"
\t
"
;
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
m
,
n
,
j
));
if
(
j
!=
j
-
1
)
std
::
cout
<<
", "
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"],
\n
"
;
}
std
::
cout
<<
"],
\n
"
;
}
std
::
cout
<<
"]
\n
"
;
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -237,6 +247,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
experts
,
shared_intermediate_size_0
,
hidden_size
});
ck_tile
::
HostTensor
<
ODataType
>
c_host
({
tokens
,
intermediate_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
experts
,
hidden_size
,
shared_intermediate_size_1
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
...
...
@@ -269,6 +280,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
g_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
d_host
);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
// ck_tile::FillConstant<DDataType>{1}(d_host);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
...
...
@@ -389,6 +403,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
);
ck_tile
::
DeviceMem
c_buf
(
c_host
);
c_buf
.
SetZero
();
std
::
cout
<<
"
\n
c size: "
<<
c_buf
.
GetBufferSize
()
<<
" tokens * intermediate_size: "
<<
tokens
*
intermediate_size
<<
std
::
endl
;
// manually clear output buffer for atomic
o_buf
.
SetZero
();
...
...
@@ -428,7 +446,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts
,
topk
,
stride
,
max_num_tokens_padded
};
max_num_tokens_padded
,
c_buf
.
GetDeviceBuffer
()};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
...
@@ -469,6 +488,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
c_dev
=
c_buf
.
ToHost
<
ADataType
>
();
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
o_dev
<<
std
::
endl
;
// std::cout << c_dev << std::endl;
// int count = 0;
// std::cout << "[";
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "[";
// for(int j = 0; j < intermediate_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(c_dev(count++)) << ",";
// }
// std::cout << "],\n";
// }
// std::cout << "]\n";
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
...
...
include/ck_tile/host/fill.hpp
View file @
b885995c
...
...
@@ -340,7 +340,7 @@ template <typename T>
struct
FillConstant
{
T
value_
{
0
};
FillConstant
(
float
value
)
:
value_
(
ck_tile
::
type_convert
<
T
>
(
value
)){}
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
...
...
include/ck_tile/host/host_tensor.hpp
View file @
b885995c
...
...
@@ -586,7 +586,7 @@ struct HostTensor
}
if
constexpr
(
std
::
is_same_v
<
T
,
bf16_t
>
||
std
::
is_same_v
<
T
,
fp16_t
>
)
{
os
<<
type_convert
<
float
>
(
t
.
mData
[
idx
])
<<
"
####
"
;
os
<<
type_convert
<
float
>
(
t
.
mData
[
idx
])
<<
" "
;
}
else
{
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
b885995c
...
...
@@ -137,6 +137,7 @@ void reference_fused_moe(
", 1:"
+
std
::
to_string
(
intermediate_size_1
));
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
{
//y(0, i_n) = acc_0(0, i_n);
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, y(0, i_n));
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
b885995c
...
...
@@ -199,6 +199,7 @@ struct FusedMoeGemmGlKernel
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
void
*
c_ptr
;
};
// TODO: switch karg based on
...
...
@@ -255,8 +256,6 @@ struct FusedMoeGemmGlKernel
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
idx_m0
;
// start block_m
// position
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
...
...
@@ -305,19 +304,26 @@ struct FusedMoeGemmGlKernel
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
idx_n0
,
0
});
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// for(int i = 0; i < 16; i++)
// {
// printf("in G index is %d , value is: %f\n",
// i,
// ck_tile::type_convert<float>(g_ptr[i]));
// }
// }
return
g_window_
;
}();
auto
c_window
=
[
&
]()
{
YDataType
*
c_ptr
=
reinterpret_cast
<
YDataType
*>
(
kargs
.
c_ptr
);
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
auto
c_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
intermediate_size
),
make_tuple
(
kargs
.
intermediate_size
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
auto
c_window_
=
make_tile_window
(
c_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_N0
>
{}),
{
0
,
0
});
return
c_window_
;
}();
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
;
...
...
@@ -371,7 +377,8 @@ struct FusedMoeGemmGlKernel
topk_weight
,
smem
,
kargs
.
hidden_size
,
kargs
.
intermediate_size
);
kargs
.
intermediate_size
,
c_window
);
}
};
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
b885995c
...
...
@@ -118,6 +118,7 @@ struct FusedMoeGemmHostArgs
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
void
*
c_ptr
;
};
// This is scatter/gather b2b group-gemm
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
b885995c
...
...
@@ -68,16 +68,24 @@ struct FusedMoeGemmPipeline_General
static
constexpr
const
char
*
name
=
"flatmm_gl"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
A
()
{
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_K0
*
sizeof
(
ADataType
);
return
smem_mat_a
;
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
GetSmemSizeA
();
constexpr
index_t
smem_mat_d
=
BlockShape
::
Block_N0
*
BlockShape
::
Block_K0
*
sizeof
(
GDataType
);
// shuffle C matrix
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_mat_a
,
smem_bridge
);
return
max
(
smem_mat_a
+
smem_mat_d
,
smem_bridge
);
// return Policy::template GetSmemSize<Problem>();
}
...
...
@@ -117,7 +125,11 @@ struct FusedMoeGemmPipeline_General
});
});
}
template
<
typename
AWindow
,
typename
GWindow
,
typename
DWindow
,
typename
OWindow
>
template
<
typename
AWindow
,
typename
GWindow
,
typename
DWindow
,
typename
OWindow
,
typename
CWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
GWindow
&
g_window_
,
const
DWindow
&
d_window_
,
...
...
@@ -125,9 +137,16 @@ struct FusedMoeGemmPipeline_General
TopkWeightDataType
topk_weight
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
hidden_size
,
index_t
intermediate_size
)
index_t
/*intermediate_size*/
,
CWindow
&
c_window_
)
{
ignore
=
topk_weight
;
ignore
=
c_window_
;
ignore
=
hidden_size
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
GDataType
*
smem_1
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
GDataType
*>
(
smem_0
+
GetSmemSizeA
()
/
sizeof
(
ADataType
));
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
auto
a_lds_win
=
make_tile_window
(
...
...
@@ -135,6 +154,13 @@ struct FusedMoeGemmPipeline_General
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
auto
g_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_1
,
Policy
::
template
MakeLdsBlockDesc_G
<
Problem
>());
auto
g_lds_win
=
make_tile_window
(
g_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
auto
a_global_to_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
...
...
@@ -148,69 +174,85 @@ struct FusedMoeGemmPipeline_General
g_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>());
// gemm gate
#if 0
PrintMem(g_dram_block, "G", 0);
#endif
// gemm0(gate)
constexpr
auto
gemm_0
=
Policy
::
template
GetBlockGemm0
<
Problem
>();
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
// save tokens to lds
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
#if 0
PrintMem(a_dram_block,"A", 0, 1);
#endif
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
// block_sync_load_raw();
// save tokens to lds
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
g_lds_win
,
g_dram_block
);
#if
1
PrintMem
(
g
_dram_block
,
"G
"
,
0
,
1
);
#if
0
PrintMem(
a
_dram_block,
"A
", 0);
#endif
clear_tile
(
s_acc
);
// initialize C
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kK0
);
index_t
iCounter0
=
k0_loops
-
1
;
while
(
iCounter0
>
0
)
while
(
iCounter0
>=
0
)
{
if
(
iCounter0
>
0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
block_sync_lds
();
move_tile_window
(
a_global_to_dram_window
,
{
0
,
kK0
});
move_tile_window
(
g_global_to_dram_window
,
{
0
,
kK0
});
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
}
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_lds_win
);
// gemm_0(s_acc, a_lds_win, g_dram_block);
block_sync_lds
();
if
(
iCounter0
>
0
)
{
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
g_lds_win
,
g_dram_block
);
}
iCounter0
--
;
}
// tail
{
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
// {
// block_sync_lds();
// // gemm_0(s_acc, a_lds_win, g_dram_block);
// gemm_0(s_acc, a_lds_win, g_lds_win);
// block_sync_lds();
// }
#if 0
PrintMem(s_acc);
PrintMem(s_acc
, "S", 0
);
#endif
// relu
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
// const auto activation = ck_tile::element_wise::Gelu{};
// tile_elementwise_inout(activation, s_acc, s_acc);
// cast data to YDataType
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
// move sacc to LDS
#if 0
PrintMem(y_pre, "Y_pre", 0);
#endif
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
block_sync_lds
();
store_tile
(
c_window_
,
y_pre
);
}
// save to lds
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
auto
bridge_slds_win
=
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// cast data to YDataType
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
#if 0
PrintMem(y_pre);
#endif
// save to lds
store_tile
(
bridge_slds_win
,
y_pre
);
block_sync_lds
();
...
...
@@ -225,7 +267,20 @@ struct FusedMoeGemmPipeline_General
{
0
,
0
},
Policy
::
template
MakeYTileDistribution
<
Problem
>());
auto
y
=
load_tile
(
bridge_llds_win
);
block_sync_lds
();
#if 0
PrintMem(y,"Y",0);
//PrintMem(y,"Y",32);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
for(int i = 0; i < 16; i++)
{
printf("\n smem_0[%d]: %f ", i, type_convert<float>(smem_0[i]));
}
}
//store_tile(c_window_, y);
#endif
// d data
auto
d_global_to_dram_window
=
make_tile_window
(
d_window_
.
get_bottom_tensor_view
(),
...
...
@@ -234,20 +289,20 @@ struct FusedMoeGemmPipeline_General
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
auto
d
=
load_tile
(
d_global_to_dram_window
);
#if 0
PrintMem(d,"D",
64
);
PrintMem(d,"D",
0
);
#endif
// add to LDS
auto
o_
a
lds_view
=
auto
o_lds_view
=
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
atomic_add
>
(
smem_0
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
make_tuple
(
32
,
1
),
number
<
8
>
{},
number
<
1
>
{});
auto
o_alds_win
=
make_tile_window
(
o_
a
lds_view
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
{
0
,
0
});
make_tile_window
(
o_lds_view
,
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
{
0
,
0
});
auto
o_olds_win
=
make_tile_window
(
o_
a
lds_view
,
make_tile_window
(
o_lds_view
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
{
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
...
...
@@ -278,31 +333,38 @@ struct FusedMoeGemmPipeline_General
gemm_1
(
o_acc
,
y
,
d
);
// block_sync_lds();
tile_elementwise_inout
(
[
&
topk_weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
topk_weight
);
},
o_acc
);
//
tile_elementwise_inout(
//
[&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
#if 0
PrintMem(o, "O", 65);
#endif
store_tile
(
o_alds_win
,
o
);
block_sync_lds
();
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
// {
// for(int i = 0; i < 42; i++)
// {
// printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
// }
// }
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
if
(
threadIdx
.
x
<
64
)
{
auto
o_out
=
load_tile
(
o_olds_win
);
block_sync_lds
();
store_tile
(
o_window_
,
o_out
);
auto
o0
=
load_tile
(
o_olds_win
);
for
(
int
step
=
1
;
step
<
4
;
step
++
)
{
move_tile_window
(
o_olds_win
,
{
32
,
0
});
auto
o1
=
load_tile
(
o_olds_win
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
o0
.
get_thread_buffer
()(
i
)
=
type_convert
<
ODataType
>
(
type_convert
<
float
>
(
o0
.
get_thread_buffer
()[
i
])
+
type_convert
<
float
>
(
o1
.
get_thread_buffer
()[
i
]));
}
}
update_tile
(
o_window_
,
o0
);
}
}
// ignore = o_olds_win;
// store_tile(o_window_, o);
#if 0
PrintMem(o,"O");
#endif
}
// store_tile(o_window_, a_dram_block);
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
b885995c
...
...
@@ -10,6 +10,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
...
...
@@ -198,7 +200,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
1
,
1
>>
{});
}
template
<
typename
Problem
>
...
...
@@ -214,13 +216,17 @@ struct FusedMoeGemmPipelineGeneralPolicy
typename
S_
::
WarpTile_0
>>
;
constexpr
auto
warp_gemm
=
GetWarpGemm0
<
Problem
>
();
using
BlockGemmPolicy
=
BlockGemmASmemBRegCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
// using BlockGemmPolicy =
// BlockGemmASmemBRegCRegV1CustomPolicy<typename
// Problem::ADataType,
typename
Problem
::
GDataType
,
typename
Problem
::
AccDataType
,
typename
S_
::
WarpPerBlock_0
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
// return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template
<
typename
Problem
>
...
...
@@ -288,28 +294,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
return
d_block_dstr
;
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
// {
// using S_ = remove_cvref_t<typename Problem::BlockShape>;
// using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// // using CDataType = typename WarpGemm::CDataType;
// constexpr auto c_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<>,
// tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
// sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
// tuple<sequence<1, 2>>,
// tuple<sequence<1, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
// constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// return c_block_dstr;
// }
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsBlockDesc_A
()
{
...
...
@@ -322,7 +306,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kK0
>
{},
number
<
Block_M
>
{},
number
<
kK1
>
{}),
make_tuple
(
number
<
(
Block_M
+
1
)
*
kK1
>
{},
number
<
kK1
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
Block_M
*
kK1
>
{},
number
<
kK1
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
...
...
@@ -333,9 +317,47 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<Block_M>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return
a_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsBlockDesc_G
()
{
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
Block_K
=
Problem
::
BlockShape
::
Block_K0
;
constexpr
index_t
kK1
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
kK0
=
Block_K
/
kK1
;
static_assert
(
Block_K
%
kK1
==
0
);
constexpr
auto
d_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kK0
>
{},
number
<
Block_N
>
{},
number
<
kK1
>
{}),
make_tuple
(
number
<
Block_N
*
kK1
>
{},
number
<
kK1
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
d_lds_block_desc
=
transform_tensor_descriptor
(
d_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
Block_N
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kK0
>
{},
number
<
kK1
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
// constexpr auto d_lds_block_desc = make_naive_tensor_descriptor(
// make_tuple(number<Block_N>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return
d_lds_block_desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsBlockDesc
()
{
...
...
@@ -343,11 +365,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
constexpr
index_t
KPad
=
0
;
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
make_tuple
(
number
<
Block_N
+
KPad
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
Block_N
>
{},
number
<
1
>
{}),
number
<
KVector
>
{},
number
<
1
>
{});
return
desc
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment