gaoqiong / composable_kernel_ROCM · Commits

Commit 24e18ae8
Authored Oct 15, 2024 by Jing Zhang
Parent c3a4652a

    fixed coord reset

Showing 5 changed files with 41 additions and 46 deletions:
  example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp                                        +1  -2
  example/01_gemm/run_gemm_example_v2.inc                                           +8  -8
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp            +3  -6
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp       +2  -1
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  +27 -29
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp  (view file @ 24e18ae8)
@@ -48,10 +48,9 @@ using DeviceGemmV2Instance =
         S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
         S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
-        2, 32, 32, 1,
+        2, 32, 32, 0,
         1, 1, S<1, 16, 1, 8>, 4,
         ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
 #endif
 // clang-format on
example/01_gemm/run_gemm_example_v2.inc  (view file @ 24e18ae8)
@@ -224,14 +224,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         //                   get_rtol<CDataType>(),
         //                   get_atol<CDataType>());

-        for(int i = 0; i < M; i++)
-        {
-            for(int j = 0; j < N; j++)
-            {
-                std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
-            }
-            std::cout << std::endl;
-        }
+        // for(int i = 0; i < M; i++)
+        // {
+        //     for(int j = 0; j < N; j++)
+        //     {
+        //         std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
+        //     }
+        //     std::cout << std::endl;
+        // }
     }

     if(config.time_kernel)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp  (view file @ 24e18ae8)
@@ -775,7 +775,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         {
             // NLdsLayer * K0 as logical Bank
             constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType);
-            constexpr auto NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
+            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
             constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
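The change above only pins NLdsLayer's declared type to index_t; the ternary already evaluates to an integral value that auto would deduce equivalently. For intuition about the quantity itself, here is a standalone sketch of the bank-layer arithmetic, with assumed illustrative values that are not taken from this kernel's configuration:

    // Sketch only: illustrative numbers, not values from this kernel config.
    #include <cstdio>

    using index_t = int;

    int main()
    {
        constexpr index_t KPerBlock  = 64; // assumed block tile depth in K
        constexpr index_t SizeofElem = 2;  // assumed element size in bytes (e.g. fp16)

        // 32 banks * 4 bytes = one 128-byte LDS bank row; LdsSize counts how
        // many N "layers" of one K-slice fit into that row.
        constexpr index_t LdsSize   = 32 * 4 / KPerBlock / SizeofElem;
        constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; // clamp at one layer

        std::printf("LdsSize = %d, NLdsLayer = %d\n", LdsSize, NLdsLayer);
        return 0;
    }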
@@ -1318,17 +1318,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-
         // Cast after lds
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ADataType*>(p_shared), a_block_space_size_aligned);
+            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                          a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
-            b_block_space_size_aligned);
+            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
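The hunk above keeps carving both tiles out of the single p_shared allocation, but now sizes each dynamic buffer by its descriptor's full element space instead of the LDS-aligned element count; the byte offset of B is still scaled down by APackedSize when A stores packed int4. A host-side sketch of the underlying offset math, with made-up sizes (the names mirror the kernel's, the numbers are illustrative only):

    // Sketch only: stand-in sizes; mirrors the offset math, not the real kernel.
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        constexpr std::size_t a_block_space_size_aligned = 4096; // elements, assumed
        constexpr std::size_t sizeof_ADataType           = 2;    // fp16 A, assumed
        constexpr std::size_t APackedSize                = 1;    // would be 2 for pk_i4 A

        static char p_shared[64 * 1024]; // stand-in for the shared-memory block

        char* a_tile = p_shared; // A tile lives at the start of LDS
        // B tile starts right after A's aligned footprint, measured in bytes
        char* b_tile = p_shared + a_block_space_size_aligned * sizeof_ADataType / APackedSize;

        std::printf("A tile at +0 bytes, B tile at +%td bytes\n", b_tile - a_tile);
        return 0;
    }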
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (view file @ 24e18ae8)
@@ -1133,12 +1133,13 @@ struct ThreadwiseTensorSliceTransfer_v4
         }
         else if constexpr(SrcBuffer::IsStaticBuffer())
         {
+            static_assert(false, "");
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                     i * src_scalar_step_in_vector);

-                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
+                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
             });
         }
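The static-buffer branch above is fenced off with static_assert(false, "") and its indexing drops the / PackedSize scaling. PackedSize is the pk_i4 packing factor: logical offsets count 4-bit elements, while the buffer is indexed in bytes that each hold two of them. A minimal sketch of that addressing rule (the nibble order is an assumption for illustration):

    // Sketch only: demonstrates offset / PackedSize indexing for packed int4 data.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        constexpr int PackedSize = 2; // two 4-bit values per stored byte
        const std::uint8_t packed[4] = {0x21, 0x43, 0x65, 0x87}; // assume low nibble first

        for(int offset = 0; offset < 8; ++offset) // logical element offsets
        {
            const std::uint8_t byte = packed[offset / PackedSize]; // storage index
            const std::uint8_t val  = (offset % PackedSize == 0) ? (byte & 0xF) : (byte >> 4);
            std::printf("element %d -> %u\n", offset, val);
        }
        return 0;
    }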
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  (view file @ 24e18ae8)
@@ -185,10 +185,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

             // maintain a container record is_src_valid, waiting for RunWrite use.
-            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
-
-            src_oob_thread_scratch_tuple_(thread_scratch_id)
-                .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
+            // const bool is_src_valid =
+            //     coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+
+            // src_oob_thread_scratch_tuple_(thread_scratch_id)
+            //     .template SetAsType<bool>(src_data_idx_seq, is_src_valid);

             using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
             using src_vector_t    = typename src_vector_type::type;
@@ -347,13 +347,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 using vector_t =
                     typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;

-                auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
+                auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
                                 .template GetAsType<vector_t>(src_data_idx_seq);

-                const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
-                                              .template GetAsType<bool>(src_data_idx_seq);
-
-                auto op_r_v = is_src_valid ? op_r : vector_t(0);
+                // const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
+                //                               .template GetAsType<bool>(src_data_idx_seq);
+
+                // auto op_r_v = is_src_valid ? op_r : vector_t(0);

                 src_thread_scratch_tuple_(thread_scratch_id)
                     .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
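Taken together with the previous hunk: RunRead used to record a per-vector is_src_valid flag in an out-of-bounds scratch buffer and zero-fill invalid vectors at this point; after the commit the loaded vector is stored back unconditionally, and RunWrite (below) passes true as the store predicate instead of a computed flag. A simplified contrast, with toy types standing in for the scratch buffers:

    // Sketch only: toy types; contrasts the removed zero-fill with the new pass-through.
    #include <array>
    #include <cstdio>

    using vector_t = std::array<float, 4>;

    int main()
    {
        vector_t op_r{1.f, 2.f, 3.f, 4.f}; // value read from the thread scratch
        bool is_src_valid = false;         // what the OOB scratch used to record

        vector_t before = is_src_valid ? op_r : vector_t{}; // old: zero-fill when invalid
        vector_t after  = op_r;                             // new: unconditional pass-through

        std::printf("before[0] = %g, after[0] = %g\n", before[0], after[0]);
        return 0;
    }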
@@ -537,8 +537,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             constexpr auto dst_data_idx_seq = generate_sequence_v2(
                 [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});

-            const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+            // const bool is_dst_valid =
+            //     coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t    = typename dst_vector_type::type;
@@ -552,15 +552,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 // apply DstElementwiseOperation
                 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
-
-                dst_vector_container.template AsType<DstData>()(i) = dst_v;
             });

             // copy data from dst_vector_container to dst_buf
             dst_buf.template Set<dst_vector_t>(
                 dst_coord_.GetOffset() / PackedSize,
-                is_dst_valid,
+                true,
                 dst_vector_container.template AsType<dst_vector_t>()[I0]);

             constexpr auto move_on_dim = [&]() constexpr
             {
@@ -612,7 +610,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
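This hunk and the ones that follow switch the scratch-descriptor math from SrcScalarPerVector / DstScalarPerVector to the underscore-suffixed template parameters SrcScalarPerVector_ / DstScalarPerVector_, which suggests the class carries both a raw parameter and a derived count adjusted for packed storage. A toy model of that split; the exact relationship is an assumption inferred from the rename, not something visible in this diff:

    // Sketch only: hypothetical relationship between the raw template
    // parameter and a derived per-vector count adjusted for packed storage.
    #include <cstdio>

    template <int SrcScalarPerVector_, int PackedSize>
    struct TransferModel
    {
        // Derived count (assumed form; the real class may define this differently).
        static constexpr int SrcScalarPerVector = SrcScalarPerVector_ / PackedSize;
    };

    int main()
    {
        using M = TransferModel<8, 2>;
        // The scratch descriptors in this commit size themselves with the raw 8,
        // not the derived 4.
        std::printf("SrcScalarPerVector_ = 8, derived SrcScalarPerVector = %d\n",
                    M::SrcScalarPerVector);
        return 0;
    }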
@@ -670,7 +668,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -756,12 +754,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

         constexpr auto src_access_lengths_and_vector_length = container_push_back(
-            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
+            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector_>{});

         // 1st stage of transforms
         constexpr auto desc0 =
@@ -805,7 +803,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -816,12 +814,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     {
         // 1st stage of transforms
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

         constexpr auto dst_access_lengths_and_vector_length = container_push_back(
-            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
+            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector_>{});

         constexpr auto desc0 =
             make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
@@ -874,12 +872,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         decltype(src_thread_scratch_desc_),
                                         true>;

-    using SrcOOBThreadScratch =
-        StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
-                                        bool, // apply data_convert with SrcThreadScratch
-                                        1,
-                                        decltype(src_oob_thread_scratch_desc_),
-                                        true>;
+    // using SrcOOBThreadScratch =
+    //     StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+    //                                     bool, // apply data_convert with SrcThreadScratch
+    //                                     1,
+    //                                     decltype(src_oob_thread_scratch_desc_),
+    //                                     true>;

     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
@@ -888,7 +886,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         true>;

     StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
-    StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
+    // StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;

     DstThreadScratch dst_thread_scratch_;