Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e30e7a8c
"...composable_kernel_rocm.git" did not exist on "6d4450ef155c39af9ede2cd171be40ee06db9939"
Commit
e30e7a8c
authored
Nov 27, 2023
by
muozturk
Browse files
test case for complex contraction bilinear
parent
b45dd4d6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
332 additions
and
0 deletions
+332
-0
test/complex_contraction_bilinear/CMakeLists.txt
test/complex_contraction_bilinear/CMakeLists.txt
+13
-0
test/complex_contraction_bilinear/test_complex_contraction_bilinear.cpp
...ontraction_bilinear/test_complex_contraction_bilinear.cpp
+124
-0
test/complex_contraction_bilinear/test_complex_contraction_bilinear_interface.cpp
..._bilinear/test_complex_contraction_bilinear_interface.cpp
+195
-0
No files found.
test/complex_contraction_bilinear/CMakeLists.txt
0 → 100755
View file @
e30e7a8c
# GPU architectures for which the complex contraction bilinear tests are built.
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)

# Guard flag: flipped to 1 once the test targets have been registered, so they
# are added at most once even when several entries of GPU_TARGETS match.
set(target 0)

foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        # Build only when fp32 or fp64 is requested, or when DTYPES is not set at all.
        if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES)
            add_gtest_executable(test_complex_contraction_bilinear
                                 test_complex_contraction_bilinear.cpp)
            target_link_libraries(test_complex_contraction_bilinear
                                  PRIVATE utility device_contraction_bilinear_instance)

            add_gtest_executable(test_complex_contraction_bilinear_interface
                                 test_complex_contraction_bilinear_interface.cpp)
            target_link_libraries(test_complex_contraction_bilinear_interface
                                  PRIVATE utility device_contraction_bilinear_instance)

            set(target 1)
        endif()
    endif()
endforeach()
test/complex_contraction_bilinear/test_complex_contraction_bilinear.cpp
0 → 100755
View file @
e30e7a8c
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <initializer_list>
#include <vector>
#include <tuple>
#include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
// Scalar element types exercised by the typed tests.
using F32 = float;
using F64 = double;

// Tensor layout tags.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element-wise operations applied to the contraction result.
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale    = ck::tensor_operation::element_wise::Scale;
/// One contraction problem size: the extents of the M, N and K index groups.
/// The fixture below reads entries [0] and [1] of each vector.
struct Dimensions
{
    std::vector<ck::index_t> M; // extents of the M index group
    std::vector<ck::index_t> N; // extents of the N index group
    std::vector<ck::index_t> K; // extents of the K index group
};
/// Typed GoogleTest fixture that runs ck::profiler::profile_contraction_impl
/// for every combination of `dimension_list` x `init_methods`.
///
/// `Tuple` is a std::tuple whose elements are, by index:
///   0: A layout, 1: B layout, 2: C/D layout, 3: data type,
///   4: D tuple data type, 5: compute data type, 6: CDE element-wise operation.
///
/// A test must assign `p_cd_element_op` before calling Run().
template <typename Tuple>
class TestContraction : public ::testing::Test
{
    protected:
    using ALayout         = std::tuple_element_t<0, Tuple>;
    using BLayout         = std::tuple_element_t<1, Tuple>;
    using CDLayout        = std::tuple_element_t<2, Tuple>;
    using DataType        = std::tuple_element_t<3, Tuple>;
    using DTupleDataType  = std::tuple_element_t<4, Tuple>;
    using ComputeDataType = std::tuple_element_t<5, Tuple>;
    using CDElementOp     = std::tuple_element_t<6, Tuple>;

    // Problem sizes to run, each given as {M extents, N extents, K extents}.
    std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
                                              {{16, 16}, {32, 32}, {16, 16}}};

    // Initialisation method ids forwarded to the profiler (their meaning is defined there).
    std::vector<ck::index_t> init_methods = {1, 2};

    // Element-wise operation under test; owned by the fixture, set by each test body.
    std::unique_ptr<CDElementOp> p_cd_element_op;

    /// Runs the profiler with verification enabled for every configured
    /// problem size and init method, expecting each run to pass.
    void Run()
    {
        // Run() dereferences the operation below; stop with a clear failure
        // instead of undefined behaviour if the test forgot to set it.
        ASSERT_TRUE(p_cd_element_op != nullptr)
            << "p_cd_element_op must be set before calling Run()";

        for(auto& dimension_params : dimension_list)
        {
            std::vector<ck::index_t> StridesA;
            std::vector<ck::index_t> StridesB;
            std::vector<ck::index_t> StridesC;
            std::vector<ck::index_t> StridesD;

            const auto& M = dimension_params.M;
            const auto& N = dimension_params.N;
            const auto& K = dimension_params.K;

            // Derive default strides for each tensor from its layout tag and lengths.
            assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
            assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
            assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
            assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});

            for(const ck::index_t init_method : init_methods)
            {
                bool pass =
                    ck::profiler::profile_contraction_impl<ALayout,
                                                           BLayout,
                                                           CDLayout,
                                                           DataType,
                                                           ComputeDataType,
                                                           DTupleDataType,
                                                           CDElementOp>(true /*do_verification*/,
                                                                        init_method,
                                                                        false /*do_logs*/,
                                                                        false /*time_kernel*/,
                                                                        *p_cd_element_op,
                                                                        dimension_params.M,
                                                                        dimension_params.N,
                                                                        dimension_params.K,
                                                                        StridesA,
                                                                        StridesB,
                                                                        StridesC,
                                                                        StridesD);
                EXPECT_TRUE(pass);
            }
        }
    }
};
/// Fixture for the Bilinear typed test suite; adds nothing to TestContraction
/// and exists so the suite has its own name.
template <typename Tuple>
class TestContractionBilinear : public TestContraction<Tuple>
{
};
// Expands to four fixture tuples covering every Row/Col combination of the
// A and B layouts; the C/D layout is Row in all of them.
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op)    \
    std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>,     \
    std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>,     \
    std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>,     \
    std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>

// Type list for the typed suite: all layout combinations, in F32 and then in F64.
using BilinearKernelTypes =
    ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);

// Runs the fixture twice: once with Bilinear(1, 1) and once with Bilinear(-0.5, 0.5).
TYPED_TEST(TestContractionBilinear, bilinear)
{
    // Installs a Bilinear op built from the two coefficients, then runs all cases.
    const auto run_with = [this](float first_coeff, float second_coeff) {
        this->p_cd_element_op = std::make_unique<Bilinear>(first_coeff, second_coeff);
        this->Run();
    };

    run_with(1.f, 1.f);
    run_with(-0.5f, 0.5f);
}
test/complex_contraction_bilinear/test_complex_contraction_bilinear_interface.cpp
0 → 100755
View file @
e30e7a8c
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <stdexcept>
#include <vector>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
#include "ck/library/utility/device_memory.hpp"
// Element-wise operations used by the interface tests.
using Pass     = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;

// Shorthand for a compile-time integer sequence.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Scalar element types.
using F32 = float;
using F64 = double;
template
<
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
CDEBlockTransferScalarPerVector
>
class
ContractionInstanceWrapper
{
public:
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
ck
::
index_t
NumDim
=
2
;
// clang-format off
using
ContractionDeviceInstance
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDim
,
NumDim
,
NumDim
,
F32
,
F32
,
F32
,
F32
,
ck
::
Tuple
<
F32
>
,
F32
,
Pass
,
Pass
,
Bilinear
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
CDEBlockTransferScalarPerVector
,
F32
>
;
// clang-format on
bool
isSupported
(
std
::
vector
<
ck
::
index_t
>&
ADims
,
std
::
vector
<
ck
::
index_t
>&
BDims
,
std
::
vector
<
ck
::
index_t
>&
DDims
,
std
::
vector
<
ck
::
index_t
>&
EDims
,
std
::
vector
<
ck
::
index_t
>&
AStrides
,
std
::
vector
<
ck
::
index_t
>&
BStrides
,
std
::
vector
<
ck
::
index_t
>&
DStrides
,
std
::
vector
<
ck
::
index_t
>&
EStrides
)
const
{
auto
contraction
=
ContractionDeviceInstance
{};
auto
argument
=
contraction
.
MakeArgument
(
nullptr
,
nullptr
,
std
::
array
<
const
void
*
,
1
>
{
nullptr
},
nullptr
,
ADims
,
AStrides
,
BDims
,
BStrides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
DDims
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
DStrides
},
EDims
,
EStrides
,
Pass
{},
Pass
{},
Bilinear
{
1.
f
,
1.
f
});
return
contraction
.
IsSupportedArgument
(
argument
);
}
};
template
<
typename
DataTypeA
,
typename
DataTypeB
,
typename
DataTypeC
,
typename
DataTypeD
,
ck
::
index_t
NumDim
>
class
ContractionDeviceOpWrapper
{
protected:
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleD
<
NumDim
,
NumDim
,
NumDim
,
DataTypeA
,
DataTypeB
,
ck
::
Tuple
<
DataTypeC
>
,
DataTypeD
,
Pass
,
Pass
,
Bilinear
>
;
public:
bool
IsSupportedInstance
(
std
::
vector
<
ck
::
index_t
>&
Dims
,
std
::
vector
<
ck
::
index_t
>&
Strides
)
const
{
bool
supported
=
false
;
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
nullptr
,
nullptr
,
std
::
array
<
const
void
*
,
1
>
{
nullptr
},
nullptr
,
Dims
,
Strides
,
Dims
,
Strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
Dims
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
Strides
},
Dims
,
Strides
,
Pass
{},
Pass
{},
Bilinear
{
1.
f
,
1.
f
});
supported
=
supported
||
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
());
}
return
supported
;
}
};
// Checks which dimensionalities have a supporting instance: the test expects
// support only for NumDim == 2, and none for NumDim == 1 or NumDim == 3.
TEST(TestContractionInterface, IncorrectNumDims)
{
    // One lengths/strides pair per wrapper; each vector holds 2 * NumDim entries.
    std::vector<ck::index_t> lengths_1d = {4, 4};
    std::vector<ck::index_t> lengths_2d = {4, 4, 4, 4};
    std::vector<ck::index_t> lengths_3d = {4, 4, 4, 4, 4, 4};

    std::vector<ck::index_t> strides_1d = {1, 1};
    std::vector<ck::index_t> strides_2d = {1, 1, 1, 1};
    std::vector<ck::index_t> strides_3d = {1, 1, 1, 1, 1, 1};

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;

    EXPECT_FALSE(wrapper_1d.IsSupportedInstance(lengths_1d, strides_1d));
    EXPECT_TRUE(wrapper_2d.IsSupportedInstance(lengths_2d, strides_2d));
    EXPECT_FALSE(wrapper_3d.IsSupportedInstance(lengths_3d, strides_3d));
}
// Mixed F32/F64 type combinations are expected to have no supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> lengths = {4, 4, 4, 4};
    std::vector<ck::index_t> strides = {64, 16, 4, 1};

    // F32 inputs (A, B) with F64 for the remaining two types.
    ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> f32_in_f64_out;
    // F64 inputs (A, B) with F32 for the remaining two types.
    ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> f64_in_f32_out;

    EXPECT_FALSE(f32_in_f64_out.IsSupportedInstance(lengths, strides));
    EXPECT_FALSE(f64_in_f32_out.IsSupportedInstance(lengths, strides));
}
// Checks how the A and B stride layouts interact with the configured
// A/B block-transfer source vector dimension (template arguments 1 and 2).
TEST(TestContractionSupportedArgs, ABMemoryAccess)
{
    std::vector<ck::index_t> lengths         = {4, 4, 4, 4};
    std::vector<ck::index_t> default_strides = {64, 16, 4, 1};
    // Unit stride at index 1.
    std::vector<ck::index_t> unit_at_1_strides = {4, 1, 64, 16};
    // Unit stride at index 3 (same values as default_strides).
    std::vector<ck::index_t> unit_at_3_strides = {64, 16, 4, 1};
    // No entry has unit stride.
    std::vector<ck::index_t> no_unit_strides = {4, 4, 4, 4};

    // ---- Memory access to A: vary the A source vector dimension ----
    ContractionInstanceWrapper<1, 2, 4> a_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> a_vector_dim_2;

    EXPECT_FALSE(a_vector_dim_1.isSupported(lengths,
                                            lengths,
                                            lengths,
                                            lengths,
                                            no_unit_strides,
                                            default_strides,
                                            default_strides,
                                            default_strides));
    EXPECT_FALSE(a_vector_dim_2.isSupported(lengths,
                                            lengths,
                                            lengths,
                                            lengths,
                                            no_unit_strides,
                                            default_strides,
                                            default_strides,
                                            default_strides));
    EXPECT_TRUE(a_vector_dim_1.isSupported(lengths,
                                           lengths,
                                           lengths,
                                           lengths,
                                           unit_at_1_strides,
                                           default_strides,
                                           default_strides,
                                           default_strides));
    EXPECT_TRUE(a_vector_dim_2.isSupported(lengths,
                                           lengths,
                                           lengths,
                                           lengths,
                                           unit_at_3_strides,
                                           default_strides,
                                           default_strides,
                                           default_strides));

    // ---- Memory access to B: vary the B source vector dimension ----
    ContractionInstanceWrapper<2, 1, 4> b_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> b_vector_dim_2;

    EXPECT_FALSE(b_vector_dim_1.isSupported(lengths,
                                            lengths,
                                            lengths,
                                            lengths,
                                            default_strides,
                                            no_unit_strides,
                                            default_strides,
                                            default_strides));
    EXPECT_FALSE(b_vector_dim_2.isSupported(lengths,
                                            lengths,
                                            lengths,
                                            lengths,
                                            default_strides,
                                            no_unit_strides,
                                            default_strides,
                                            default_strides));
    EXPECT_TRUE(b_vector_dim_1.isSupported(lengths,
                                           lengths,
                                           lengths,
                                           lengths,
                                           default_strides,
                                           unit_at_1_strides,
                                           default_strides,
                                           default_strides));
    EXPECT_TRUE(b_vector_dim_2.isSupported(lengths,
                                           lengths,
                                           lengths,
                                           lengths,
                                           default_strides,
                                           unit_at_3_strides,
                                           default_strides,
                                           default_strides));
}
// Checks that the D and E tensors are rejected when their last dimension does
// not have unit stride, and accepted with the default strides.
TEST(TestContractionSupportedArgs, DEMemoryAccess)
{
    std::vector<ck::index_t> lengths         = {4, 4, 4, 4};
    std::vector<ck::index_t> default_strides = {64, 16, 4, 1};
    // Last two strides swapped: unit stride sits at index 2 instead of index 3.
    std::vector<ck::index_t> swapped_tail_strides = {64, 16, 1, 4};

    ContractionInstanceWrapper<2, 2, 4> instance;

    // ---- Memory access to D ----
    EXPECT_FALSE(instance.isSupported(lengths,
                                      lengths,
                                      lengths,
                                      lengths,
                                      default_strides,
                                      default_strides,
                                      swapped_tail_strides,
                                      default_strides));
    EXPECT_TRUE(instance.isSupported(lengths,
                                     lengths,
                                     lengths,
                                     lengths,
                                     default_strides,
                                     default_strides,
                                     default_strides,
                                     default_strides));

    // ---- Memory access to E ----
    EXPECT_FALSE(instance.isSupported(lengths,
                                      lengths,
                                      lengths,
                                      lengths,
                                      default_strides,
                                      default_strides,
                                      default_strides,
                                      swapped_tail_strides));
    EXPECT_TRUE(instance.isSupported(lengths,
                                     lengths,
                                     lengths,
                                     lengths,
                                     default_strides,
                                     default_strides,
                                     default_strides,
                                     default_strides));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment