gaoqiong / composable_kernel · Commits · 6916e3e4

Commit 6916e3e4, authored Nov 30, 2022 by rocking (parent ad2f82ac)

Clean the code

Showing 3 changed files with 149 additions and 211 deletions (+149, -211)
Changed files:

  include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp (+5, -8)
  include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp (+0, -6)
  include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp (+144, -197)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -125,11 +125,11 @@ __global__ void
         const GammaDataType* __restrict__ p_gamma_grid,
         const BetaDataType* __restrict__ p_beta_grid,
         HDataType* __restrict__ p_h_grid,
-        const EHGridDesc_M_N& e_grid_desc_m_n,
-        const EHGridDesc_M_N& h_grid_desc_m_n,
-        const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
-        const GammaBetaGridDesc_N& gamma_grid_desc_n,
-        const GammaBetaGridDesc_N& beta_grid_desc_n,
+        const EHGridDesc_M_N e_grid_desc_m_n,
+        const EHGridDesc_M_N h_grid_desc_m_n,
+        const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
+        const GammaBetaGridDesc_N gamma_grid_desc_n,
+        const GammaBetaGridDesc_N beta_grid_desc_n,
         index_t blkgroup_size,
         index_t num_mean_var_count_k_block_tile_iteration,
         index_t num_xy_k_block_tile_iteration,
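The five descriptor parameters above change from pass-by-reference to pass-by-value. Kernel arguments are copied into device-visible argument storage at launch, so a reference parameter on a __global__ function would point back at host memory; trivially copyable descriptor objects have to travel by value. A minimal sketch of the pattern, with a hypothetical TensorDesc and fill_kernel standing in for the CK types:

    #include <hip/hip_runtime.h>

    // Hypothetical descriptor standing in for EHGridDesc_M_N.
    // Trivially copyable, so it can be passed to a kernel by value.
    struct TensorDesc
    {
        int lengths[2];
        int strides[2];
    };

    // The descriptor arrives by value: its bytes are copied into the
    // kernel argument space, which is why a host-side reference would
    // not be usable here.
    __global__ void fill_kernel(float* p_out, const TensorDesc desc, float value)
    {
        const int n = desc.lengths[0] * desc.lengths[1];
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            p_out[i] = value;
    }

    int main()
    {
        TensorDesc desc{{4, 8}, {8, 1}};
        float* p_out = nullptr;
        (void)hipMalloc(&p_out, sizeof(float) * 4 * 8);
        hipLaunchKernelGGL(fill_kernel, dim3(1), dim3(64), 0, 0, p_out, desc, 1.0f);
        (void)hipDeviceSynchronize();
        (void)hipFree(p_out);
        return 0;
    }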
@@ -507,9 +507,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
             mean_var_count_grid_desc_m_n_ =
                 DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
-            int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
-            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
-
             hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));

             int gemm_welford_size = MRaw * gemm_nblock_;
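The surviving hip_check_error(hipMalloc(...)) line shows the repository's convention of failing fast on HIP runtime errors while the debug printf goes away. An illustrative stand-alone version of such a checker (a sketch, not CK's actual hip_check_error definition):

    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    // Illustrative stand-in for CK's hip_check_error(): abort with a
    // readable message when a HIP runtime call does not return hipSuccess.
    inline void hip_check_error(hipError_t err)
    {
        if(err != hipSuccess)
        {
            std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err));
            std::abort();
        }
    }

    int main()
    {
        float* p = nullptr;
        hip_check_error(hipMalloc(&p, sizeof(float) * 1024)); // fails loudly on error
        hip_check_error(hipFree(p));
        return 0;
    }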
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
@@ -1080,12 +1080,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                     count_thread_buf,
                     mean_var_count_grid_desc_mblock_mperblock_nblock,
                     welford_count_grid_buf);
-
-                float mean = static_cast<float>(mean_thread_buf(I0));
-                float var  = static_cast<float>(var_thread_buf(I0));
-                int count  = count_thread_buf(I0);
-                if(i == 0 && get_thread_global_1d_id() == 0)
-                    printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
             });
         }
         // shuffle C + Ds + welford + write out
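The deleted block printed the Welford partials (mean, variance, count) that the first-half kernel accumulates before handing them to the second-half kernel. For reference, Welford's single-pass update, which is what such mean/variance/count buffers implement, looks like this in plain C++ (a sketch, not CK's ThreadwiseWelford):

    #include <cstdio>

    // Welford's online algorithm: one pass, numerically stable.
    // mean/variance()/count correspond to the per-thread partials the
    // first-half kernel produces before the blockwise merge.
    struct Welford
    {
        float mean = 0.0f;
        float m2   = 0.0f; // running sum of squared deviations from the mean
        int count  = 0;

        void push(float x)
        {
            ++count;
            const float delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean);
        }

        float variance() const { return count > 0 ? m2 / count : 0.0f; }
    };

    int main()
    {
        Welford w;
        for(float x : {1.0f, 2.0f, 3.0f, 4.0f})
            w.push(x);
        std::printf("mean = %f, var = %f, count = %d\n", w.mean, w.variance(), w.count);
        return 0;
    }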
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
@@ -63,8 +63,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+    static constexpr auto thread_buffer_desc_m_1 =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+
+    using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
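The refactor above constructs the packed M x 1 descriptor once, names it thread_buffer_desc_m_1, and derives ThreadReduceSrcDesc_M_1 from the named value with decltype instead of repeating the factory expression. A self-contained illustration of that idiom, with a hypothetical make_desc in place of make_naive_tensor_descriptor_packed:

    #include <type_traits>
    #include <utility>

    // Hypothetical factory standing in for make_naive_tensor_descriptor_packed.
    template <int... Ls>
    constexpr auto make_desc()
    {
        return std::integer_sequence<int, Ls...>{};
    }

    // Name the value once...
    inline constexpr auto thread_buffer_desc_m_1 = make_desc<8, 1>();

    // ...then derive the type from the named value instead of repeating the
    // factory expression, so the two can never drift apart.
    using ThreadReduceSrcDesc_M_1 = std::remove_const_t<decltype(thread_buffer_desc_m_1)>;

    static_assert(std::is_same_v<ThreadReduceSrcDesc_M_1, std::integer_sequence<int, 8, 1>>);

    int main() { return 0; }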
@@ -124,201 +128,144 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         ignore = numEBlockTileIteration_N;
         ignore = epsilon;

-        // float mean = static_cast<float>(p_in_welford_mean_grid[0]);
-        // float var = static_cast<float>(p_in_welford_var_grid[0]);
-        // int count = p_in_welford_count_grid[0];
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
-
-        float mean = static_cast<float>(p_in_welford_mean_grid[0]);
-        if(get_thread_global_1d_id() == 0)
-            printf("mean = %f\n", mean);
-
-        int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        if(get_thread_global_1d_id() == 0)
-            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
-
-        // using ThreadBufferLengths_1_1 = Sequence<1, 1>;
-        // constexpr auto thread_buffer_desc_1_1 =
-        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
-        // constexpr auto grid_desc_1_1 =
-        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
-
-        // const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
-
-        // float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("global mean = %f\n", mean1);
-
-        // auto threadwise_mean_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
-        //                                      ComputeDataType,
-        //                                      decltype(mean_var_count_grid_desc_m_n),
-        //                                      decltype(thread_buffer_desc_1_1),
-        //                                      ThreadBufferLengths_1_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(mean_var_count_grid_desc_m_n,
-        //                                            make_multi_index(0, 0));
-
-        // threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //                              mean_grid,
-        //                              thread_buffer_desc_1_1,
-        //                              make_tuple(I0, I0),
-        //                              mean_thread);
-
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
-
-        // // Thread/Block id
-        // const index_t thread_local_id = get_thread_local_1d_id();
-        // const index_t block_global_id = get_block_1d_id();
-        // const auto thread_cluster_idx =
-        //     thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
-        // const auto thread_m_cluster_id = thread_cluster_idx[I0];
-        // const auto thread_n_cluster_id = thread_cluster_idx[I1];
-
-        // // step1: Merge mean and variance
-        // using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
-        // constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
-        //     make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-
-        // auto threadwise_mean_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
-        //                                      ComputeDataType,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(mean_var_count_grid_desc_m_n,
-        //                                            make_multi_index(0, 0));
-
-        // auto threadwise_var_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<VarDataType,
-        //                                      ComputeDataType,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(
-        //     mean_var_count_grid_desc_m_n,
-        //     make_multi_index(block_global_id * M_BlockTileSize +
-        //                          thread_m_cluster_id * MThreadSliceSize,
-        //                      thread_n_cluster_id));
-
-        // auto threadwise_count_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<int32_t,
-        //                                      int32_t,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(
-        //     mean_var_count_grid_desc_m_n,
-        //     make_multi_index(block_global_id * M_BlockTileSize +
-        //                          thread_m_cluster_id * MThreadSliceSize,
-        //                      thread_n_cluster_id));
-
-        // const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        // const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        // const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     in_welford_mean_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     in_welford_var_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-        //     in_welford_count_thread_buf;
-
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     welford_mean_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     welford_var_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-        //     welford_count_thread_buf;
-
-        // constexpr auto mean_var_count_thread_copy_step_m_n =
-        //     make_multi_index(0, NThreadClusterSize);
-
-        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //     welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
-        //     welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
-        //     welford_count_thread_buf(I) = 0;
-        // });
-
-        // for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
-        //     ++reducedTiles)
-        // {
-        //     threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //                                  welford_mean_global_val_buf,
-        //                                  thread_buffer_desc_m_1,
-        //                                  make_tuple(I0, I0),
-        //                                  in_welford_mean_thread_buf);
-        //     // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //     //                             welford_var_global_val_buf,
-        //     //                             thread_buffer_desc_m_1,
-        //     //                             make_tuple(I0, I0),
-        //     //                             in_welford_var_thread_buf);
-        //     // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //     //                               welford_count_global_val_buf,
-        //     //                               thread_buffer_desc_m_1,
-        //     //                               make_tuple(I0, I0),
-        //     //                               in_welford_count_thread_buf);
-        //     // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
-        //     //                        in_welford_var_thread_buf,
-        //     //                        in_welford_count_thread_buf,
-        //     //                        welford_mean_thread_buf,
-        //     //                        welford_var_thread_buf,
-        //     //                        welford_count_thread_buf);
-        //     // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                             mean_var_count_thread_copy_step_m_n);
-        //     // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                            mean_var_count_thread_copy_step_m_n);
-        //     // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                              mean_var_count_thread_copy_step_m_n);
-        //     static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //         if(get_thread_global_1d_id() == 0)
-        //             printf("mean = %f, var = %f, count = %d\n",
-        //                    in_welford_mean_thread_buf(I),
-        //                    in_welford_var_thread_buf(I),
-        //                    in_welford_count_thread_buf(I));
-        //     });
-        // }
-
-        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //     if constexpr(I > 0)
-        //         block_sync_lds();
-        //     if(get_thread_global_1d_id() == 0)
-        //         printf("count = %d\n", welford_count_thread_buf(I));
-        //     BlockwiseWelford::Run(
-        //         welford_mean_thread_buf(I), welford_var_thread_buf(I),
-        //         welford_count_thread_buf(I));
-        // });
+        // Thread/Block id
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+
+        // step1: Merge mean and variance
+        auto threadwise_mean_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+                                             ComputeDataType,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        auto threadwise_var_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<VarDataType,
+                                             ComputeDataType,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        auto threadwise_count_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<int32_t,
+                                             int32_t,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            in_welford_count_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            welford_count_thread_buf;
+
+        constexpr auto mean_var_count_thread_copy_step_m_n =
+            make_multi_index(0, NThreadClusterSize);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
+            welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
+            welford_count_thread_buf(I) = 0;
+        });

+        for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
+            ++reducedTiles)
+        {
+            threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                         welford_mean_global_val_buf,
+                                         thread_buffer_desc_m_1,
+                                         make_tuple(I0, I0),
+                                         in_welford_mean_thread_buf);
+
+            threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                        welford_var_global_val_buf,
+                                        thread_buffer_desc_m_1,
+                                        make_tuple(I0, I0),
+                                        in_welford_var_thread_buf);
+
+            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                          welford_count_global_val_buf,
+                                          thread_buffer_desc_m_1,
+                                          make_tuple(I0, I0),
+                                          in_welford_count_thread_buf);
+
+            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
+                                   in_welford_var_thread_buf,
+                                   in_welford_count_thread_buf,
+                                   welford_mean_thread_buf,
+                                   welford_var_thread_buf,
+                                   welford_count_thread_buf);
+
+            threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                        mean_var_count_thread_copy_step_m_n);
+            threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                       mean_var_count_thread_copy_step_m_n);
+            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                         mean_var_count_thread_copy_step_m_n);
+        }
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            if constexpr(I > 0)
+                block_sync_lds();
+
+            BlockwiseWelford::Run(
+                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
+        });
+
+        // step2: normalization
+        for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
+        {
+            // TODO
+        }
     } // run
 };
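After this cleanup, step 1 of the second-half kernel merges the per-N-block (mean, variance, count) partials of each row into a single statistic via ThreadwiseWelford/BlockwiseWelford, and step 2 (still a // TODO in this commit) will use it for the layernorm normalization. A serial reference for that merge, using the standard pairwise Welford combination of Chan et al. (a plain C++ sketch under those assumptions, not the CK implementation):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct Partial { float mean; float var; int count; };

    // Pairwise merge of two Welford partials (Chan et al.): the same kind of
    // combination the blockwise merge applies across the thread cluster.
    Partial merge(Partial a, Partial b)
    {
        const int n = a.count + b.count;
        if(n == 0) return {0.0f, 0.0f, 0};
        const float delta = b.mean - a.mean;
        const float mean  = a.mean + delta * b.count / n;
        // var holds population variance; convert to M2, merge, convert back.
        const float m2 = a.var * a.count + b.var * b.count +
                         delta * delta * a.count * b.count / n;
        return {mean, m2 / n, n};
    }

    int main()
    {
        // One row's partials, one per N-block written by the first-half kernel.
        std::vector<Partial> partials = {{1.5f, 0.25f, 2}, {3.5f, 0.25f, 2}};
        Partial total = {0.0f, 0.0f, 0};
        for(const Partial& p : partials)
            total = merge(total, p);

        // step2: layernorm normalization with the merged statistics.
        const float eps = 1e-5f;
        const float x = 4.0f, gamma = 1.0f, beta = 0.0f;
        const float y = gamma * (x - total.mean) / std::sqrt(total.var + eps) + beta;
        std::printf("mean = %f, var = %f, count = %d, y = %f\n",
                    total.mean, total.var, total.count, y);
        return 0;
    }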