Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a30c626b
Commit
a30c626b
authored
Sep 27, 2023
by
Bartlomiej Wroblewski
Browse files
Make ComputeDataType an optional argument
parent
b019d839
Changes
66
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
17 deletions
+15
-17
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
...scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
...scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
...scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp
...scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp
+2
-2
profiler/include/profiler/profile_contraction_impl.hpp
profiler/include/profiler/profile_contraction_impl.hpp
+2
-2
test/contraction/test_contraction_interface.cpp
test/contraction/test_contraction_interface.cpp
+5
-7
No files found.
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
View file @
a30c626b
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64
,
F64
,
Empty_Tuple
,
Empty_Tuple
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
)
Scale
,
F64
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
{});
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
{});
...
...
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
View file @
a30c626b
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64
,
F64
,
Empty_Tuple
,
Empty_Tuple
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
)
Scale
,
F64
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance
{});
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance
{});
...
...
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
View file @
a30c626b
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64
,
F64
,
Empty_Tuple
,
Empty_Tuple
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
)
Scale
,
F64
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance
{});
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance
{});
...
...
library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp
View file @
a30c626b
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
...
@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64
,
F64
,
Empty_Tuple
,
Empty_Tuple
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
)
Scale
,
F64
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance
{});
instances
,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance
{});
...
...
profiler/include/profiler/profile_contraction_impl.hpp
View file @
a30c626b
...
@@ -124,10 +124,10 @@ int profile_contraction_impl(ck::index_t do_verification,
...
@@ -124,10 +124,10 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType
,
DataType
,
DTupleDataType
,
DTupleDataType
,
DataType
,
DataType
,
ComputeDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CDElementOp
>
;
CDElementOp
,
ComputeDataType
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
test/contraction/test_contraction_interface.cpp
View file @
a30c626b
...
@@ -75,7 +75,6 @@ template <typename DataTypeA,
...
@@ -75,7 +75,6 @@ template <typename DataTypeA,
typename
DataTypeB
,
typename
DataTypeB
,
typename
DataTypeC
,
typename
DataTypeC
,
typename
DataTypeD
,
typename
DataTypeD
,
typename
DataTypeCompute
,
ck
::
index_t
NumDim
>
ck
::
index_t
NumDim
>
class
ContractionDeviceOpWrapper
class
ContractionDeviceOpWrapper
{
{
...
@@ -88,7 +87,6 @@ class ContractionDeviceOpWrapper
...
@@ -88,7 +87,6 @@ class ContractionDeviceOpWrapper
DataTypeB
,
DataTypeB
,
ck
::
Tuple
<
DataTypeC
>
,
ck
::
Tuple
<
DataTypeC
>
,
DataTypeD
,
DataTypeD
,
DataTypeCompute
,
Pass
,
Pass
,
Pass
,
Pass
,
Bilinear
>
;
Bilinear
>
;
...
@@ -131,9 +129,9 @@ TEST(TestContractionInterface, IncorrectNumDims)
...
@@ -131,9 +129,9 @@ TEST(TestContractionInterface, IncorrectNumDims)
{
{
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Dims
=
{{
4
,
4
},
{
4
,
4
,
4
,
4
},
{
4
,
4
,
4
,
4
,
4
,
4
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Dims
=
{{
4
,
4
},
{
4
,
4
,
4
,
4
},
{
4
,
4
,
4
,
4
,
4
,
4
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Strides
=
{{
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
,
1
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Strides
=
{{
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
,
1
}};
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
;
EXPECT_FALSE
(
wrapper_1d
.
IsSupportedInstance
(
Dims
[
0
],
Strides
[
0
]));
EXPECT_FALSE
(
wrapper_1d
.
IsSupportedInstance
(
Dims
[
0
],
Strides
[
0
]));
EXPECT_TRUE
(
wrapper_2d
.
IsSupportedInstance
(
Dims
[
1
],
Strides
[
1
]));
EXPECT_TRUE
(
wrapper_2d
.
IsSupportedInstance
(
Dims
[
1
],
Strides
[
1
]));
EXPECT_FALSE
(
wrapper_3d
.
IsSupportedInstance
(
Dims
[
2
],
Strides
[
2
]));
EXPECT_FALSE
(
wrapper_3d
.
IsSupportedInstance
(
Dims
[
2
],
Strides
[
2
]));
...
@@ -143,8 +141,8 @@ TEST(TestContractionInterface, IncorrectDataTypes)
...
@@ -143,8 +141,8 @@ TEST(TestContractionInterface, IncorrectDataTypes)
{
{
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
ContractionDeviceOpWrapper
<
F32
,
F32
,
F64
,
F64
,
F32
,
2
>
wrapper_1
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F64
,
F64
,
2
>
wrapper_1
;
ContractionDeviceOpWrapper
<
F64
,
F64
,
F32
,
F32
,
F32
,
2
>
wrapper_2
;
ContractionDeviceOpWrapper
<
F64
,
F64
,
F32
,
F32
,
2
>
wrapper_2
;
EXPECT_FALSE
(
wrapper_1
.
IsSupportedInstance
(
Dims
,
Strides
));
EXPECT_FALSE
(
wrapper_1
.
IsSupportedInstance
(
Dims
,
Strides
));
EXPECT_FALSE
(
wrapper_2
.
IsSupportedInstance
(
Dims
,
Strides
));
EXPECT_FALSE
(
wrapper_2
.
IsSupportedInstance
(
Dims
,
Strides
));
}
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment