gaoqiong / composable_kernel · Commits · ac6977f7
"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "307c524bac267c99a93dea5a694acc11f4b1536f"
Commit ac6977f7 authored May 29, 2022 by Anthony Chang

    tidy up

Parent: 2d91fd12
Showing 2 changed files with 5 additions and 107 deletions
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp    +0 -27
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +5 -80
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp
@@ -336,32 +336,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
         }
     }

-    // assuming packed tensor
-    static auto MakeGridDescriptor_M(index_t MRaw)
-    {
-        const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
-
-        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto MPad = M - MRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                     GemmSpec == GemmSpecialization::MNPadding ||
-                     GemmSpec == GemmSpecialization::MKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad M
-            return transform_tensor_descriptor(grid_desc_mraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
-                                               make_tuple(Sequence<0>{}),
-                                               make_tuple(Sequence<0>{}));
-        }
-        else
-        {
-            // not pad M
-            return grid_desc_mraw;
-        }
-    }
-
     static auto MakeGridDescriptor_N(index_t NRaw)
     {
         const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
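Note: the deleted MakeGridDescriptor_M rounded the raw GEMM M dimension up to a multiple of the tile size before right-padding. A minimal standalone sketch of that arithmetic, for reference; the `integer_divide_ceil` helper is modeled on ck's math::integer_divide_ceil, and the example values are assumptions, not code from this commit:

    #include <cstdio>

    using index_t = int;

    // same rounding that math::integer_divide_ceil provides in ck
    constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

    int main()
    {
        constexpr index_t MPerBlock = 256; // example tile size (assumed)
        const index_t MRaw = 1000;         // raw GEMM M dimension (assumed)

        const index_t M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
        const index_t MPad = M - MRaw;                                         // 24

        std::printf("M = %d, MPad = %d\n", M, MPad);
        return 0;
    }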
@@ -604,7 +578,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
                 typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 false>;

             ave_time = launch_and_time_kernel(stream_config,
                                               kernel,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
@@ -182,9 +182,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
             make_naive_tensor_descriptor_packed(
                 make_tuple(I1,
-                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{}, // 1 * MWave * 32
+                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                            I1,
-                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{})); // 1 * NWave * 32
+                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
     }
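Note: the descriptor above encodes a (1, M-extent, 1, N-extent) LDS tile for the C-shuffle stage. A quick numeric reading with assumed parameter values (the actual values are template parameters, not shown in this commit):

    #include <cstdio>

    int main()
    {
        // assumed example values for illustration only
        const int CShuffleMXdlPerWavePerShuffle = 1, MWave = 2, MPerXdl = 32;
        const int CShuffleNXdlPerWavePerShuffle = 1, NWave = 2, NPerXdl = 32;

        std::printf("c_shuffle LDS tile: 1 x %d x 1 x %d\n",
                    CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,  // 64
                    CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); // 64
        return 0;
    }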
@@ -296,24 +296,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return c_grid_desc_mblock_mperblock_nblock_nperblock;
     }

-    // for broadcasting bias, beta, gamma
-    // __host__ __device__ static constexpr auto
-    // MakeC0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock)
-    // {
-    //     const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
-    //     const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
-    //         c0_grid_desc_nblock_nperblock,
-    //         make_tuple(make_insert_transform(I1),
-    //                    make_insert_transform(I1),
-    //                    make_pass_through_transform(NBlock),
-    //                    make_pass_through_transform(NPerBlock)),
-    //         make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
-    //         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-    //     return c0_grid_desc_mblock_mperblock_nblock_nperblock;
-    // }
-
     // for bias, beta, gamma
     __host__ __device__ static constexpr auto
     MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n)
@@ -411,16 +393,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());

-        // if (hipThreadIdx_x == 0 && hipBlockIdx_x == 0) c_grid_desc_mblock_mperblock_nblock_nperblock.Print();
-        /*
-        {TensorDescriptor,
-        transforms: {Embed, up_lengths_ {MultiIndex, size 2,256 128 }coefficients_ {MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 0 }UpperDimensionIds:{size 2, 1 2 }
-        transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 256 }up_lengths_scan_{MultiIndex, size 2,256 1 }}LowerDimensionIds:{size 1, 1 }UpperDimensionIds:{size 2, 3 4 }
-        transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 128 }up_lengths_scan_{MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 2 }UpperDimensionIds:{size 2, 5 6 }
-        }
-        {size 4, 3 4 5 6 }
-        */
-
         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
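Note: CalculateBottomIndex decomposes the 1-D workgroup id into the [M, N] coordinate of the C tile the block owns. A simplified, hypothetical model of that mapping; ck's DefaultBlock2CTileMap is descriptor-based and more involved, so this only illustrates the idea:

    #include <cstdio>

    int main()
    {
        const int MBlock = 4, NBlock = 3; // assumed grid of C tiles
        for(int block_id = 0; block_id < MBlock * NBlock; ++block_id)
        {
            const int m = block_id / NBlock; // row-major decomposition
            const int n = block_id % NBlock;
            std::printf("block %2d -> C tile [%d, %d]\n", block_id, m, n);
        }
        return 0;
    }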
@@ -891,7 +863,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         static_for<0, num_access, 1>{}([&](auto access_id) {
             // make sure it's safe to write to LDS
-            // __syncthreads();
             block_sync_lds();

             // each thread write its data from VGPR to LDS
@@ -901,9 +872,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                 c_shuffle_block_buf);

-            // make sure it's safe to read from LDS
-            // debug::print_shared(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
-            // __syncthreads();
             block_sync_lds();

             // layernorm
@@ -924,34 +892,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c0_thread_buf);

             static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
                 [&](auto i) {
-                    // auto thread_slice_desc = make_cluster_descriptor(
-                    //     Sequence<mreduce_per_thread, nreduce_per_thread>{});
-                    // auto thread_slice_idx = thread_slice_desc.CalculateBottomIndex(make_multi_index(i));
-                    // printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c = %f\n",
-                    //        hipThreadIdx_x,
-                    //        access_id.value,
-                    //        thread_slice_idx[I0],
-                    //        thread_slice_idx[I1],
-                    //        c0_thread_buf(i),
-                    //        c_reduce_thread_buf(i));
                     c_reduce_thread_buf(i) += c0_thread_buf(i);
                 });

-            // static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-            //     static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-            //         constexpr auto offset =
-            //             Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-            //                 make_tuple(im, in))>{};
-            //         c_reduce_thread_buf(offset) += c0_thread_buf(offset);
-            //         // printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c+c0 = %f\n",
-            //         //        hipThreadIdx_x,
-            //         //        access_id.value,
-            //         //        im.value,
-            //         //        in.value,
-            //         //        c0_thread_buf(offset),
-            //         //        c_reduce_thread_buf(offset));
-            //     });
-            // });
-
             using ThreadwiseReduceD0 =
                 ThreadwiseReduction<FloatReduceAcc,
@@ -979,7 +921,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             // reduce squared sum in VGPR
             ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);

-            // reduce across workgorup
+            // reduce within workgroup
             using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc,
                                                                   BlockSize,
                                                                   CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
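Note: the D0/D1 reductions accumulate the row sum and the row squared sum, so mean and variance come out of a single sweep over the data. In LaTeX (the epsilon term and the per-column gamma/beta application are assumptions based on standard layernorm; the commit does not show them in this hunk):

    \mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
    \sigma^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \mu^2, \qquad
    y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i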
@@ -992,17 +934,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum
                 block_sync_lds();
                 BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum
-
-                // printf("tid %zd, access_id %d, mreduce_idx %d, sum = %f, sq sum = %f\n",
-                //        hipThreadIdx_x,
-                //        access_id.value,
-                //        i.value,
-                //        d0_thread_buf(i),
-                //        d1_thread_buf(i));
             });

             // normalize
             const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock
                                      .GetTransforms()[I0]
                                      .GetUpperLengths()[I1]; // TODO: proper handle
-            // if(hipThreadIdx_x == 0) printf("NRaw = %d\n", NRaw);

             static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
                 static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
                     constexpr auto dst_offset =
@@ -1022,14 +958,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(
                         divisor_sqrt, divisor);

                     c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
-
-                    // printf("tid %zd, access_id %d, reduce_idx %d %d, avg_sum = %f, avg sq sum = %f, final = %f\n",
-                    //        hipThreadIdx_x,
-                    //        access_id.value,
-                    //        im.value,
-                    //        in.value,
-                    //        avg_sum,
-                    //        avg_squared_sum,
-                    //        c_reduce_thread_buf(dst_offset));
                 });
             });
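Note: the UnarySqrt / division pair above finalizes the one-pass scheme: take the square root of the variance-like `divisor` and scale the centered value by its reciprocal. A scalar sketch of the whole pipeline under those assumptions; the names and epsilon here are illustrative, not the kernel's own definitions:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const std::vector<float> x{1.f, 2.f, 3.f, 4.f};
        const float eps = 1e-5f; // assumed epsilon
        const int N = static_cast<int>(x.size());

        // single pass: accumulate sum and squared sum, as the D0/D1 reductions do
        float sum = 0.f, sq_sum = 0.f;
        for(float v : x)
        {
            sum += v;
            sq_sum += v * v;
        }

        const float mean = sum / N;
        const float var  = sq_sum / N - mean * mean; // E[x^2] - E[x]^2

        for(float v : x)
            std::printf("%f ", (v - mean) / std::sqrt(var + eps)); // normalize
        std::printf("\n");
        return 0;
    }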
@@ -1056,7 +984,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
             });

-            // __syncthreads();
             block_sync_lds();

             c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
@@ -1067,9 +994,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             }
             // end layernorm

-            // __syncthreads();
             block_sync_lds();

-            // debug::print_shared<32>(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
             // each block copy its data from LDS to global
             c_shuffle_block_copy_lds_to_global.Run(
@@ -1086,7 +1011,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                     c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);

-                // move on C0 bias
+                // move on C0
                 c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                     c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
             }