Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1caa8198
Commit
1caa8198
authored
Nov 19, 2024
by
“letaoqin”
Browse files
write a, g,d and o tensor
parent
84755f74
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
113 additions
and
59 deletions
+113
-59
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+2
-1
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+40
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+47
-50
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+3
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
+21
-7
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
1caa8198
...
@@ -297,7 +297,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -297,7 +297,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens
,
tokens
,
experts
,
experts
,
topk
,
topk
,
stride
};
stride
,
max_num_tokens_padded
};
float
ave_time
=
fused_moegemm
(
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
...
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
1caa8198
...
@@ -57,4 +57,44 @@ struct indexing_adaptor_onshot_cached
...
@@ -57,4 +57,44 @@ struct indexing_adaptor_onshot_cached
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
}
};
};
template
<
typename
IndexingType
>
struct
indexing_adaptor
{
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
(
const
IndexingType
*
idx
)
:
cached_idx_
(
idx
)
{}
const
IndexingType
*
cached_idx_
;
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
number
<
0
>
{})
=
*
(
cached_idx_
+
idx_up
[
number
<
0
>
{}]);
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
/*idx_low*/
,
const
UpIdx
&
/*idx_up*/
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
1caa8198
...
@@ -197,6 +197,7 @@ struct FusedMoeGemmGlKernel
...
@@ -197,6 +197,7 @@ struct FusedMoeGemmGlKernel
index_t
topk
;
// need this?
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
};
};
// TODO: switch karg based on
// TODO: switch karg based on
...
@@ -230,17 +231,13 @@ struct FusedMoeGemmGlKernel
...
@@ -230,17 +231,13 @@ struct FusedMoeGemmGlKernel
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
index_t
nr_0
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Nr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Block_Kr0
;
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Block_Nr1
;
// should be same as kr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Kr1
;
// should be same as nr_0
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index(i_m and i_n)
// note this is in unit of tile, need multiple tile size to get the index(block_m and
// block_n)
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
if
(
sorted_tile_id
>=
num_sorted_tiles
)
...
@@ -252,17 +249,28 @@ struct FusedMoeGemmGlKernel
...
@@ -252,17 +249,28 @@ struct FusedMoeGemmGlKernel
// index along intermediate_size
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
// BlockShape::Block_N0);
index_t
interm_idx_nr
=
index_t
idx_m0
=
__builtin_amdgcn_readfirstlane
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
index_t
idx_n0
=
__builtin_amdgcn_readfirstlane
(
sorted_tile_id
*
BlockShape
::
Block_N0
);
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
BlockShape
::
Block_M0
;
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
index_t
token_id
=
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_token_id
];
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const
auto
sorted_token_id
=
sorted_tile_id
*
BlockShape
::
Block_M0
;
// start block_m
// position
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto
topk_weight
=
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
const
index_t
*
sorted_token_ids_ptr
=
reinterpret_cast
<
const
index_t
*>
(
&
(
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_token_id
]));
const
auto
a_window
=
[
&
]()
{
const
auto
a_window
=
[
&
]()
{
// A is already pre-padded in previous kernel
// A is already pre-padded in previous kernel
const
ADataType
*
a_ptr
=
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
ADataType
*
a_ptr
=
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
...
@@ -276,7 +284,9 @@ struct FusedMoeGemmGlKernel
...
@@ -276,7 +284,9 @@ struct FusedMoeGemmGlKernel
// gather is here use indexing transform
// gather is here use indexing transform
const
auto
a_gather_view_
=
transform_tensor_view
(
const
auto
a_gather_view_
=
transform_tensor_view
(
a_view_
,
a_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_tuple
(
make_indexing_transform_with_adaptor
(
kargs
.
max_num_tokens_padded
,
indexing_adaptor
<
index_t
>
{
sorted_token_ids_ptr
}),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
@@ -284,7 +294,7 @@ struct FusedMoeGemmGlKernel
...
@@ -284,7 +294,7 @@ struct FusedMoeGemmGlKernel
const
auto
a_window_
=
make_tile_window
(
const
auto
a_window_
=
make_tile_window
(
a_gather_view_
,
a_gather_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
{
idx_m
0
,
0
});
return
a_window_
;
return
a_window_
;
}();
}();
...
@@ -292,52 +302,38 @@ struct FusedMoeGemmGlKernel
...
@@ -292,52 +302,38 @@ struct FusedMoeGemmGlKernel
const
auto
g_window
=
[
&
]()
{
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_
idx_n
r
*
k
r_0
*
BlockShape
::
Block_W0
;
idx_n
0
*
k
args
.
hidden_size
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}
),
make_tuple
(
BlockShape
::
Block_N0
,
kargs
.
hidden_size
),
make_tuple
(
k
r_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{}
,
1
),
make_tuple
(
k
args
.
hidden_size
,
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
const
auto
g_window_
=
make_tile_window
(
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
g_view_
,
number
<
BlockShape
::
Block_Kr0
>
{},
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
});
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
});
return
g_window_
;
return
g_window_
;
}();
}();
const
auto
d_window
=
[
&
]()
{
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
i
nterm_idx_nr
*
BlockShape
::
Block_W1
;
i
dx_n0
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_
W
1
),
make_tuple
(
kargs
.
hidden_size
,
BlockShape
::
Block_
K
1
),
make_tuple
(
k
r_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
make_tuple
(
k
args
.
intermediate_size
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
const
auto
d_window_
=
make_tile_window
(
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
d_view_
,
number
<
BlockShape
::
Block_Kr1
>
{},
make_tuple
(
number
<
BlockShape
::
Block_N1
>
{},
number
<
BlockShape
::
Block_K1
>
{}),
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
});
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
});
return
d_window_
;
return
d_window_
;
}();
}();
...
@@ -354,7 +350,9 @@ struct FusedMoeGemmGlKernel
...
@@ -354,7 +350,9 @@ struct FusedMoeGemmGlKernel
// gather is here
// gather is here
auto
o_scatter_view_
=
transform_tensor_view
(
auto
o_scatter_view_
=
transform_tensor_view
(
o_view_
,
o_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_tuple
(
make_indexing_transform_with_adaptor
(
kargs
.
max_num_tokens_padded
,
indexing_adaptor
<
index_t
>
{
sorted_token_ids_ptr
}),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
@@ -362,7 +360,7 @@ struct FusedMoeGemmGlKernel
...
@@ -362,7 +360,7 @@ struct FusedMoeGemmGlKernel
auto
o_window_
=
make_tile_window
(
auto
o_window_
=
make_tile_window
(
o_scatter_view_
,
o_scatter_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
{
0
,
0
});
{
idx_m
0
,
0
});
return
o_window_
;
return
o_window_
;
}();
}();
...
@@ -374,8 +372,7 @@ struct FusedMoeGemmGlKernel
...
@@ -374,8 +372,7 @@ struct FusedMoeGemmGlKernel
topk_weight
,
topk_weight
,
smem
,
smem
,
kargs
.
hidden_size
,
kargs
.
hidden_size
,
kargs
.
intermediate_size
,
kargs
.
intermediate_size
);
kargs
.
stride_token
);
}
}
};
};
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
1caa8198
...
@@ -104,7 +104,8 @@ struct FusedMoeGemmHostArgs
...
@@ -104,7 +104,8 @@ struct FusedMoeGemmHostArgs
index_t
num_experts
;
// number of groups
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
};
};
// This is scatter/gather b2b group-gemm
// This is scatter/gather b2b group-gemm
...
@@ -198,6 +199,7 @@ struct FusedMoeGemmKernel
...
@@ -198,6 +199,7 @@ struct FusedMoeGemmKernel
index_t
topk
;
// need this?
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
};
};
// TODO: switch karg based on
// TODO: switch karg based on
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
View file @
1caa8198
...
@@ -80,16 +80,30 @@ struct FusedMoeGemmPipeline_FlatmmGl
...
@@ -80,16 +80,30 @@ struct FusedMoeGemmPipeline_FlatmmGl
return
max
(
smem_mat_a
,
smem_bridge
);
return
max
(
smem_mat_a
,
smem_bridge
);
}
}
template
<
typename
Karg
>
// this is the thread-offset along row/col
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
template
<
typename
AWindow
,
typename
GWindow
,
typename
DWindow
,
typename
OWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
GWindow
&
g_window_
,
const
DWindow
&
d_window_
,
OWindow
&
o_window_
,
TopkWeightDataType
/*topk_weight*/
,
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
sorted_tile_id
,
index_t
hidden_size
,
index_t
intermediate_
tile_id
)
index_t
intermediate_
size
)
{
{
ignore
=
kargs
;
ignore
=
a_window_
;
ignore
=
g_window_
;
ignore
=
d_window_
;
ignore
=
o_window_
;
ignore
=
smem
;
ignore
=
smem
;
ignore
=
sorted_tile_id
;
ignore
=
hidden_size
;
ignore
=
intermediate_
tile_id
;
ignore
=
intermediate_
size
;
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment