Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6d2d39ba
Commit
6d2d39ba
authored
May 12, 2023
by
Bartlomiej Kocot
Browse files
Extend test_contraction_interface
parent
1abe377b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
114 additions
and
53 deletions
+114
-53
test/contraction/test_contraction_interface.cpp
test/contraction/test_contraction_interface.cpp
+114
-53
No files found.
test/contraction/test_contraction_interface.cpp
View file @
6d2d39ba
...
...
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
...
...
@@ -16,15 +18,65 @@
using
Pass
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F32
=
float
;
using
F64
=
double
;
template
<
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
CDEBlockTransferScalarPerVector
>
class
ContractionInstanceWrapper
{
public:
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
ck
::
index_t
NumDim
=
2
;
// clang-format off
using
ContractionDeviceInstance
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDim
,
NumDim
,
NumDim
,
F32
,
F32
,
F32
,
F32
,
ck
::
Tuple
<
F32
>
,
F32
,
Pass
,
Pass
,
Bilinear
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
CDEBlockTransferScalarPerVector
>
;
// clang-format on
bool
isSupported
(
std
::
vector
<
ck
::
index_t
>&
ADims
,
std
::
vector
<
ck
::
index_t
>&
BDims
,
std
::
vector
<
ck
::
index_t
>&
DDims
,
std
::
vector
<
ck
::
index_t
>&
EDims
,
std
::
vector
<
ck
::
index_t
>&
AStrides
,
std
::
vector
<
ck
::
index_t
>&
BStrides
,
std
::
vector
<
ck
::
index_t
>&
DStrides
,
std
::
vector
<
ck
::
index_t
>&
EStrides
)
const
{
auto
contraction
=
ContractionDeviceInstance
{};
auto
argument
=
contraction
.
MakeArgument
(
nullptr
,
nullptr
,
std
::
array
<
const
void
*
,
1
>
{
nullptr
},
nullptr
,
ADims
,
AStrides
,
BDims
,
BStrides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
DDims
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
DStrides
},
EDims
,
EStrides
,
Pass
{},
Pass
{},
Bilinear
{
1.
f
,
1.
f
});
return
contraction
.
IsSupportedArgument
(
argument
);
}
};
template
<
typename
DataTypeA
,
typename
DataTypeB
,
typename
DataTypeC
,
typename
DataTypeD
,
ck
::
index_t
NumDim
>
class
ContractionDeviceWrapper
class
ContractionDevice
Op
Wrapper
{
protected:
...
...
@@ -40,26 +92,8 @@ class ContractionDeviceWrapper
Bilinear
>
;
public:
ContractionDeviceWrapper
(
std
::
vector
<
ck
::
index_t
>&
Dims
,
std
::
vector
<
ck
::
index_t
>&
Strides
)
:
InputDims_
(
Dims
),
OutputDims_
(
Dims
),
InputStrides_
(
Strides
),
OutputStrides_
(
Strides
)
{
}
ContractionDeviceWrapper
(
std
::
vector
<
ck
::
index_t
>&
InDims
,
std
::
vector
<
ck
::
index_t
>&
OutDims
,
std
::
vector
<
ck
::
index_t
>&
InStrides
,
std
::
vector
<
ck
::
index_t
>&
OutStrides
)
:
InputDims_
(
InDims
),
OutputDims_
(
OutDims
),
InputStrides_
(
InStrides
),
OutputStrides_
(
OutStrides
)
{
}
std
::
vector
<
ck
::
index_t
>&
InputDims_
;
std
::
vector
<
ck
::
index_t
>&
OutputDims_
;
std
::
vector
<
ck
::
index_t
>&
InputStrides_
;
std
::
vector
<
ck
::
index_t
>&
OutputStrides_
;
bool
IsSupported
()
const
bool
IsSupportedInstance
(
std
::
vector
<
ck
::
index_t
>&
Dims
,
std
::
vector
<
ck
::
index_t
>&
Strides
)
const
{
bool
supported
=
false
;
...
...
@@ -73,14 +107,14 @@ class ContractionDeviceWrapper
nullptr
,
std
::
array
<
const
void
*
,
1
>
{
nullptr
},
nullptr
,
InputStrides_
,
Input
Strides
_
,
InputStrides_
,
Input
Strides
_
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
InputStrides_
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
Input
Strides
_
},
Output
Dims
_
,
Output
Strides
_
,
Dims
,
Strides
,
Dims
,
Strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
Dims
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
Strides
},
Dims
,
Strides
,
Pass
{},
Pass
{},
Bilinear
{
1.
f
,
1.
f
});
...
...
@@ -95,40 +129,67 @@ TEST(TestContractionInterface, IncorrectNumDims)
{
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Dims
=
{{
4
,
4
},
{
4
,
4
,
4
,
4
},
{
4
,
4
,
4
,
4
,
4
,
4
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Strides
=
{{
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
,
1
}};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
(
Dims
[
0
],
Strides
[
0
])
;
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
(
Dims
[
1
],
Strides
[
1
])
;
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
(
Dims
[
2
],
Strides
[
2
])
;
EXPECT_FALSE
(
wrapper_1d
.
IsSupported
(
));
EXPECT_TRUE
(
wrapper_2d
.
IsSupported
(
));
EXPECT_FALSE
(
wrapper_3d
.
IsSupported
(
));
ContractionDevice
Op
Wrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
;
ContractionDevice
Op
Wrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
;
ContractionDevice
Op
Wrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
;
EXPECT_FALSE
(
wrapper_1d
.
IsSupported
Instance
(
Dims
[
0
],
Strides
[
0
]
));
EXPECT_TRUE
(
wrapper_2d
.
IsSupported
Instance
(
Dims
[
1
],
Strides
[
1
]
));
EXPECT_FALSE
(
wrapper_3d
.
IsSupported
Instance
(
Dims
[
2
],
Strides
[
2
]
));
}
TEST
(
TestContractionInterface
,
IncorrectDataTypes
)
{
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
ContractionDeviceWrapper
<
F32
,
F32
,
F64
,
F64
,
2
>
wrapper_1
(
Dims
,
Strides
)
;
ContractionDeviceWrapper
<
F64
,
F64
,
F32
,
F32
,
2
>
wrapper_2
(
Dims
,
Strides
)
;
EXPECT_FALSE
(
wrapper_1
.
IsSupported
(
));
EXPECT_FALSE
(
wrapper_2
.
IsSupported
(
));
ContractionDevice
Op
Wrapper
<
F32
,
F32
,
F64
,
F64
,
2
>
wrapper_1
;
ContractionDevice
Op
Wrapper
<
F64
,
F64
,
F32
,
F32
,
2
>
wrapper_2
;
EXPECT_FALSE
(
wrapper_1
.
IsSupported
Instance
(
Dims
,
Strides
));
EXPECT_FALSE
(
wrapper_2
.
IsSupported
Instance
(
Dims
,
Strides
));
}
TEST
(
TestContraction
Interface
,
GridwiseGemm
)
TEST
(
TestContraction
SupportedArgs
,
ABMemoryAccess
)
{
std
::
vector
<
ck
::
index_t
>
InDims
=
{
1
,
2
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
InStrides
=
{
24
,
12
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
OutDims
=
{
4
,
3
,
2
,
1
};
std
::
vector
<
ck
::
index_t
>
OutStrides
=
{
6
,
2
,
1
,
1
};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper
(
InDims
,
OutDims
,
InStrides
,
OutStrides
);
EXPECT_FALSE
(
wrapper
.
IsSupported
());
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
StridesM1
=
{
4
,
1
,
64
,
16
};
std
::
vector
<
ck
::
index_t
>
StridesK1
=
{
64
,
16
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
InvalidStrides
=
{
4
,
4
,
4
,
4
};
// Memory access to A
ContractionInstanceWrapper
<
1
,
2
,
4
>
wrapperA1
;
ContractionInstanceWrapper
<
2
,
2
,
4
>
wrapperA2
;
EXPECT_FALSE
(
wrapperA1
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
InvalidStrides
,
Strides
,
Strides
,
Strides
));
EXPECT_FALSE
(
wrapperA2
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
InvalidStrides
,
Strides
,
Strides
,
Strides
));
EXPECT_TRUE
(
wrapperA1
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
StridesM1
,
Strides
,
Strides
,
Strides
));
EXPECT_TRUE
(
wrapperA2
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
StridesK1
,
Strides
,
Strides
,
Strides
));
// Memory access to B
ContractionInstanceWrapper
<
2
,
1
,
4
>
wrapperB1
;
ContractionInstanceWrapper
<
2
,
2
,
4
>
wrapperB2
;
EXPECT_FALSE
(
wrapperB1
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
InvalidStrides
,
Strides
,
Strides
));
EXPECT_FALSE
(
wrapperB2
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
InvalidStrides
,
Strides
,
Strides
));
EXPECT_TRUE
(
wrapperB1
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
StridesM1
,
Strides
,
Strides
));
EXPECT_TRUE
(
wrapperB2
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
StridesK1
,
Strides
,
Strides
));
}
TEST
(
TestContraction
Interface
,
MemoryAccess
)
TEST
(
TestContraction
SupportedArgs
,
DE
MemoryAccess
)
{
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
4
,
16
,
64
,
256
};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper
(
Dims
,
Strides
);
EXPECT_FALSE
(
wrapper
.
IsSupported
());
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
InvalidStrides
=
{
64
,
16
,
1
,
4
};
ContractionInstanceWrapper
<
2
,
2
,
4
>
wrapper
;
// Memory access to D
EXPECT_FALSE
(
wrapper
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
Strides
,
InvalidStrides
,
Strides
));
EXPECT_TRUE
(
wrapper
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
Strides
,
Strides
,
Strides
));
// Memory access to E
EXPECT_FALSE
(
wrapper
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
Strides
,
Strides
,
InvalidStrides
));
EXPECT_TRUE
(
wrapper
.
isSupported
(
Dims
,
Dims
,
Dims
,
Dims
,
Strides
,
Strides
,
Strides
,
Strides
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment