Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
11b83234
Commit
11b83234
authored
May 08, 2022
by
Chao Liu
Browse files
clean up
parent
8e3aef3b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
36 deletions
+16
-36
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
+3
-23
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/gemm/gemm_fp16.cpp
test/gemm/gemm_fp16.cpp
+12
-12
No files found.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp
View file @
11b83234
...
@@ -284,8 +284,7 @@ struct ThreadwiseTensorSliceTransfer_v3r3
...
@@ -284,8 +284,7 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// TODO make this logic more generic for more sub-dword datatype
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
(
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
||
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
is_same
<
bhalf_t
,
remove_cvref_t
<
DstData
>>::
value
)
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
{
{
// each transpose does
// each transpose does
...
@@ -344,27 +343,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
...
@@ -344,27 +343,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// do data transpose
// do data transpose
// TODO type_convert is not used yet!!!!!
// TODO type_convert is not used yet!!!!!
transpose_convert_vectors
<
SrcData
,
transpose_vectors
<
SrcData
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
DstData
,
src_vector_refs
,
dst_vector_refs
);
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
});
}
else
if
constexpr
(
SrcVectorDim
==
DstVectorDim
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
&&
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
bhalf_t
,
remove_cvref_t
<
DstData
>>::
value
)
{
auto
NewSliceLengths
=
SliceLengths
{}.
template
Modify
(
Number
<
SrcVectorDim
>{},
Number
<
SliceLengths
{}[
SrcVectorDim
]
/
2
>
{});
auto
VectorStep
=
SliceLengths
{}
/
NewSliceLengths
;
static_ford
<
decltype
(
NewSliceLengths
)
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
auto
nidx
=
idx
*
VectorStep
;
auto
vhalf
=
src_thread_scratch_tuple_
[
thread_scratch_id
].
template
GetAsType
<
half2_t
>(
nidx
);
dst_thread_scratch_
.
template
SetAsType
<
bhalf2_t
>(
nidx
,
type_convert
<
bhalf2_t
>
(
vhalf
));
});
});
}
}
else
else
...
...
test/CMakeLists.txt
View file @
11b83234
...
@@ -59,4 +59,4 @@ add_subdirectory(batched_gemm_reduce)
...
@@ -59,4 +59,4 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
reduce
)
add_subdirectory
(
reduce
)
#
add_subdirectory(conv2d_bwd_weight)
add_subdirectory
(
conv2d_bwd_weight
)
test/gemm/gemm_fp16.cpp
View file @
11b83234
...
@@ -33,10 +33,10 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
...
@@ -33,10 +33,10 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
//
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
//
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
//
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
//
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
...
@@ -63,8 +63,8 @@ int main()
...
@@ -63,8 +63,8 @@ int main()
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
//
ck::tensor_operation::device::device_gemm_instance::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
//
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
...
@@ -85,8 +85,8 @@ int main()
...
@@ -85,8 +85,8 @@ int main()
gemmPtrs
.
clear
();
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
//
ck::tensor_operation::device::device_gemm_instance::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
//
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
...
@@ -107,8 +107,8 @@ int main()
...
@@ -107,8 +107,8 @@ int main()
gemmPtrs
.
clear
();
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
//
ck::tensor_operation::device::device_gemm_instance::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
//
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
...
@@ -129,8 +129,8 @@ int main()
...
@@ -129,8 +129,8 @@ int main()
gemmPtrs
.
clear
();
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
//
ck::tensor_operation::device::device_gemm_instance::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
//
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment