Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
75cf3655
Commit
75cf3655
authored
Dec 15, 2023
by
muozturk
Browse files
scale
parent
017fb2eb
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1223 additions
and
0 deletions
+1223
-0
example/65_complex_contraction_scale/CMakeLists.txt
example/65_complex_contraction_scale/CMakeLists.txt
+2
-0
example/65_complex_contraction_scale/common_instances.hpp
example/65_complex_contraction_scale/common_instances.hpp
+196
-0
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp32.cpp
..._contraction_scale/complex_contraction_scale_xdl_fp32.cpp
+101
-0
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp64.cpp
..._contraction_scale/complex_contraction_scale_xdl_fp64.cpp
+101
-0
example/65_complex_contraction_scale/run_complex_contraction_scale_example.inc
...ntraction_scale/run_complex_contraction_scale_example.inc
+483
-0
test/complex_contraction_scale/CMakeLists.txt
test/complex_contraction_scale/CMakeLists.txt
+13
-0
test/complex_contraction_scale/test_complex_contraction_scale.cpp
...plex_contraction_scale/test_complex_contraction_scale.cpp
+127
-0
test/complex_contraction_scale/test_complex_contraction_scale_interface.cpp
...action_scale/test_complex_contraction_scale_interface.cpp
+200
-0
No files found.
example/65_complex_contraction_scale/CMakeLists.txt
0 → 100755
View file @
75cf3655
# Complex contraction + scale examples: one executable per floating-point precision.
add_example_executable(example_complex_contraction_scale_xdl_fp32
                       complex_contraction_scale_xdl_fp32.cpp)
add_example_executable(example_complex_contraction_scale_xdl_fp64
                       complex_contraction_scale_xdl_fp64.cpp)
example/65_complex_contraction_scale/common_instances.hpp
0 → 100644
View file @
75cf3655
This diff is collapsed.
Click to expand it.
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp32.cpp
0 → 100755
View file @
75cf3655
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
// Element types used by the fp32 complex-contraction example.
using ADataType        = F32;
using BDataType        = F32;
using AccDataType      = F32;
using CShuffleDataType = F32;
using DDataType        = F32;
using EDataType        = F32;
using ComputeDataType  = F32;

// The Scale epilogue takes no auxiliary D tensor, so its D tuple is empty.
using DsDataType = ck::Tuple<>;
// The Bilinear epilogue combines the contraction result with exactly one D tensor
// (the runner passes one D pointer to the Bilinear passes), so it needs a one-element tuple.
using BilinearDsDataType = ck::Tuple<DDataType>;

static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp         = ck::tensor_operation::element_wise::PassThrough;
using BElementOp         = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp       = ck::tensor_operation::element_wise::Bilinear;
using CDEElementOp_Scale = ck::tensor_operation::element_wise::Scale;

// Bilinear instance (one D tensor) built from the KK generic template.
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

// Scale instance (no D tensor) built from the KK generic template.
using DeviceOpInstanceKKN_Scale = DeviceOpInstanceKK_Generic<NumDimM,
                                                             NumDimN,
                                                             NumDimK,
                                                             ADataType,
                                                             BDataType,
                                                             AccDataType,
                                                             CShuffleDataType,
                                                             DsDataType,
                                                             EDataType,
                                                             ComputeDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CDEElementOp_Scale>;

// Remaining Bilinear variants; not selected below but kept available for experimentation.
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

// Instances consumed by the shared example runner included below.
using DeviceOpInstance       = DeviceOpInstanceKKN;
using DeviceOpInstance_Scale = DeviceOpInstanceKKN_Scale;
#include "run_complex_contraction_scale_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_complex_contraction_scale_example
(
argc
,
argv
);
}
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp64.cpp
0 → 100755
View file @
75cf3655
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
// Element types used by the fp64 complex-contraction example.
using ADataType        = F64;
using BDataType        = F64;
using AccDataType      = F64;
using CShuffleDataType = F64;
using DDataType        = F64;
using EDataType        = F64;
using ComputeDataType  = F64;

// The Scale epilogue takes no auxiliary D tensor, so its D tuple is empty.
using DsDataType = ck::Tuple<>;
// The Bilinear epilogue combines the contraction result with exactly one D tensor
// (the runner passes one D pointer to the Bilinear passes), so it needs a one-element tuple.
using BilinearDsDataType = ck::Tuple<DDataType>;

static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp         = ck::tensor_operation::element_wise::PassThrough;
using BElementOp         = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp       = ck::tensor_operation::element_wise::Bilinear;
using CDEElementOp_Scale = ck::tensor_operation::element_wise::Scale;

// Bilinear instance (one D tensor) built from the KK generic template.
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

// Scale instance (no D tensor) built from the KK generic template.
using DeviceOpInstanceKKN_Scale = DeviceOpInstanceKK_Generic<NumDimM,
                                                             NumDimN,
                                                             NumDimK,
                                                             ADataType,
                                                             BDataType,
                                                             AccDataType,
                                                             CShuffleDataType,
                                                             DsDataType,
                                                             EDataType,
                                                             ComputeDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CDEElementOp_Scale>;

// Remaining Bilinear variants; not selected below but kept available for experimentation.
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       CShuffleDataType,
                                                       BilinearDsDataType,
                                                       EDataType,
                                                       ComputeDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>;

// Instances consumed by the shared example runner included below.
using DeviceOpInstance       = DeviceOpInstanceKKN;
using DeviceOpInstance_Scale = DeviceOpInstanceKKN_Scale;
#include "run_complex_contraction_scale_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_complex_contraction_scale_example
(
argc
,
argv
);
}
example/65_complex_contraction_scale/run_complex_contraction_scale_example.inc
0 → 100755
View file @
75cf3655
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int
run_complex_contraction_bilinear_example
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// A[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// B[N0, N1, K0, K1]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
// D[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
// E[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
float
scale
=
1.
f
;
float
alpha
=
1.
f
;
float
beta
=
1.
f
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
28
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
M0
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
M1
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
N0
=
std
::
stoi
(
argv
[
6
]);
const
ck
::
index_t
N1
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
K0
=
std
::
stoi
(
argv
[
8
]);
const
ck
::
index_t
K1
=
std
::
stoi
(
argv
[
9
]);
a_ms_ks_lengths
=
{
M0
,
M1
,
K0
,
K1
};
a_ms_ks_strides
=
{
std
::
stoi
(
argv
[
10
]),
std
::
stoi
(
argv
[
11
]),
std
::
stoi
(
argv
[
12
]),
std
::
stoi
(
argv
[
13
])};
b_ns_ks_lengths
=
{
N0
,
N1
,
K0
,
K1
};
b_ns_ks_strides
=
{
std
::
stoi
(
argv
[
14
]),
std
::
stoi
(
argv
[
15
]),
std
::
stoi
(
argv
[
16
]),
std
::
stoi
(
argv
[
17
])};
d_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
d_ms_ns_strides
=
{
std
::
stoi
(
argv
[
18
]),
std
::
stoi
(
argv
[
19
]),
std
::
stoi
(
argv
[
20
]),
std
::
stoi
(
argv
[
21
])};
e_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
e_ms_ns_strides
=
{
std
::
stoi
(
argv
[
22
]),
std
::
stoi
(
argv
[
23
]),
std
::
stoi
(
argv
[
24
]),
std
::
stoi
(
argv
[
25
])};
alpha
=
std
::
stof
(
argv
[
26
]);
beta
=
std
::
stof
(
argv
[
27
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M0, M1, N0, N1, K0, K1
\n
"
);
printf
(
"arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1
\n
"
);
printf
(
"arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1
\n
"
);
printf
(
"arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1
\n
"
);
printf
(
"arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1
\n
"
);
printf
(
"arg26 to 27: alpha, beta
\n
"
);
exit
(
0
);
}
// For Real Part of Complex Tensor
Tensor
<
ADataType
>
a_ms_ks_re
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks_re
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// For Imaginary Part of Complex Tensor
Tensor
<
ADataType
>
a_ms_ks_img
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks_img
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// // Intermediate E tensor Definition
// Tensor<EDataType> e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
// Tensor<EDataType> e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
std
::
cout
<<
"a_ms_ks_re: "
<<
a_ms_ks_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks_re: "
<<
b_ns_ks_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns_re: "
<<
e_ms_ns_host_result_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_ms_ks_img: "
<<
a_ms_ks_img
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks_img: "
<<
b_ns_ks_img
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns_img: "
<<
e_ms_ns_host_result_img
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_ms_ks_re
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_ns_ks_re
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
a_ms_ks_img
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_ns_ks_img
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
default
:
a_ms_ks_re
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_ns_ks_re
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
a_ms_ks_img
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_ns_ks_img
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
}
DeviceMem
a_device_buf_re
(
sizeof
(
ADataType
)
*
a_ms_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_re
(
sizeof
(
BDataType
)
*
b_ns_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_re
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf_img
(
sizeof
(
ADataType
)
*
a_ms_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_img
(
sizeof
(
BDataType
)
*
b_ns_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
// // Intermediate Value For E Real and Img
// DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
// DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
a_device_buf_re
.
ToDevice
(
a_ms_ks_re
.
mData
.
data
());
b_device_buf_re
.
ToDevice
(
b_ns_ks_re
.
mData
.
data
());
a_device_buf_img
.
ToDevice
(
a_ms_ks_img
.
mData
.
data
());
b_device_buf_img
.
ToDevice
(
b_ns_ks_img
.
mData
.
data
());
// set zero
e_device_buf_re
.
SetZero
();
e_device_buf_img
.
SetZero
();
// // set zero for intermediate values
// e_device_buf_re1.SetZero();
// e_device_buf_img1.SetZero();
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op_scale
=
CDEElementOp_Scale
{
scale
};
// device operation
// C_real = A_real * B_real
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument_re1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
// std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
std
::
array
<
const
void
*
,
0
>
{},
e_device_buf_re
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op_scale
);
if
(
!
op
.
IsSupportedArgument
(
argument_re1
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_re1
=
invoker
.
Run
(
argument_re1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
-
1.
f
*
scale
;
beta
=
1.
f
;
a_element_op
=
AElementOp
{};
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
// device operation
// For real Intermediate Value re_2
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_re
.
GetDeviceBuffer
()},
e_device_buf_re
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_re2
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_re2
=
invoker
.
Run
(
argument_re2
,
StreamConfig
{
nullptr
,
time_kernel
});
// scale = 1.f ;
// a_element_op = AElementOp{};
// b_element_op = BElementOp{};
// cde_element_op = CDEElementOp{alpha, beta};
auto
argument_img1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{}},
e_device_buf_img
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op_scale
);
if
(
!
op
.
IsSupportedArgument
(
argument_img1
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_img1
=
invoker
.
Run
(
argument_img1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
1.
f
*
scale
;
beta
=
1.
f
;
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_img
.
GetDeviceBuffer
()},
e_device_buf_img
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_img2
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_img2
=
invoker
.
Run
(
argument_img2
,
StreamConfig
{
nullptr
,
time_kernel
});
ck
::
index_t
M
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
(),
NumDimM
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
N
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
NumDimN
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
K
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
*
2
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
DDataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
*
2
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf_re
.
FromDevice
(
e_ms_ns_device_result_re
.
mData
.
data
());
e_device_buf_img
.
FromDevice
(
e_ms_ns_device_result_img
.
mData
.
data
());
auto
isRealOk
=
0
;
auto
isImgOk
=
0
;
if
(
do_verification
)
{
// Real Part Verification
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_re1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
CShuffleDataType
,
AccDataType
,
F32
,
AElementOp
,
BElementOp
>
;
auto
ref_op
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_op
.
MakeInvoker
();
auto
ref_argument_re
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_re
,
c_ms_ns_host_result_re
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re
);
// alpha = 1.f;
// beta = 1.f;
// cde_element_op = CDEElementOp{alpha, beta};
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op_scale
(
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
alpha
=
1.
f
*
scale
;
beta
=
-
1.
f
;
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
ref_argument_re1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_img
,
c_ms_ns_host_result_re1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re1
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_re1
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
isRealOk
=
ck
::
utils
::
check_err
(
e_ms_ns_device_result_re
,
e_ms_ns_host_result_re
)
?
0
:
1
;
// Img Part Verification
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
auto
ref_argument_img
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_img
,
c_ms_ns_host_result_img
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img
);
// alpha = 1.f;
// beta = 1.f;
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op_scale
(
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
alpha
=
1.
f
*
scale
;
beta
=
-
1.
f
;
auto
ref_argument_img1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_re
,
c_ms_ns_host_result_img1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img1
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_img1
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
isImgOk
=
ck
::
utils
::
check_err
(
e_ms_ns_device_result_re
,
e_ms_ns_host_result_re
)
?
0
:
1
;
return
(
isRealOk
&&
isImgOk
);
}
return
0
;
}
test/complex_contraction_scale/CMakeLists.txt
0 → 100755
View file @
75cf3655
# GPU architectures these tests are registered for.
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)

# `target` flips to 1 once the tests have been added, so they are registered at most once even
# when several entries of GPU_TARGETS match.
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        # Added when DTYPES is undefined or mentions fp32/fp64.
        if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES)
            add_gtest_executable(test_complex_contraction_scale
                                 test_complex_contraction_scale.cpp)
            target_link_libraries(test_complex_contraction_scale
                                  PRIVATE utility device_contraction_scale_instance)

            add_gtest_executable(test_complex_contraction_scale_interface
                                 test_complex_contraction_scale_interface.cpp)
            target_link_libraries(test_complex_contraction_scale_interface
                                  PRIVATE utility device_contraction_scale_instance)

            set(target 1)
        endif()
    endif()
endforeach()
test/complex_contraction_scale/test_complex_contraction_scale.cpp
0 → 100755
View file @
75cf3655
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <initializer_list>
#include <vector>
#include <tuple>
#include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
// Short names for the element types under test.
using F32 = float;
using F64 = double;

// Layout tags used to pick the default strides of each tensor.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element-wise operation applied to the contraction result.
using Scale = ck::tensor_operation::element_wise::Scale;
// One contraction problem size: the lengths of the M, N and K dimension groups.
struct Dimensions
{
    std::vector<ck::index_t> M; // lengths of the M dimensions
    std::vector<ck::index_t> N; // lengths of the N dimensions
    std::vector<ck::index_t> K; // lengths of the K dimensions
};
// Typed fixture that profiles a contraction for every problem size in `dimension_list` and every
// initialization method in `init_methods`, expecting each run to pass verification.
//
// `Tuple` is a std::tuple of <ALayout, BLayout, CDLayout, DataType, DTupleDataType,
// ComputeDataType, CDElementOp>. `p_cd_element_op` must be set before Run() is called.
template <typename Tuple>
class TestContraction : public ::testing::Test
{
    protected:
    using ALayout         = std::tuple_element_t<0, Tuple>;
    using BLayout         = std::tuple_element_t<1, Tuple>;
    using CDLayout        = std::tuple_element_t<2, Tuple>;
    using DataType        = std::tuple_element_t<3, Tuple>;
    using DTupleDataType  = std::tuple_element_t<4, Tuple>;
    using ComputeDataType = std::tuple_element_t<5, Tuple>;
    using CDElementOp     = std::tuple_element_t<6, Tuple>;

    // Problem sizes as {M lengths, N lengths, K lengths}.
    std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
                                              {{16, 16}, {32, 32}, {16, 16}}};

    // Initialization methods forwarded to the profiler.
    std::vector<ck::index_t> init_methods = {1, 2};

    // Element-wise operation under test; dereferenced by Run().
    std::unique_ptr<CDElementOp> p_cd_element_op;

    // Profiles every (problem size, init method) combination with verification enabled.
    void Run()
    {
        for(auto& problem : dimension_list)
        {
            const auto& m_lens = problem.M;
            const auto& n_lens = problem.N;
            const auto& k_lens = problem.K;

            // Default strides derived from each tensor's layout tag.
            std::vector<ck::index_t> a_strides;
            assign_default_strides(
                ALayout{}, a_strides, {m_lens[0], m_lens[1], k_lens[0], k_lens[1]});

            std::vector<ck::index_t> b_strides;
            assign_default_strides(
                BLayout{}, b_strides, {n_lens[0], n_lens[1], k_lens[0], k_lens[1]});

            std::vector<ck::index_t> c_strides;
            assign_default_strides(
                CDLayout{}, c_strides, {m_lens[0], m_lens[1], n_lens[0], n_lens[1]});

            std::vector<ck::index_t> d_strides;
            assign_default_strides(
                CDLayout{}, d_strides, {m_lens[0], m_lens[1], n_lens[0], n_lens[1]});

            for(const ck::index_t init_method : init_methods)
            {
                const bool passed =
                    ck::profiler::profile_contraction_impl<ALayout,
                                                           BLayout,
                                                           CDLayout,
                                                           DataType,
                                                           ComputeDataType,
                                                           DTupleDataType,
                                                           CDElementOp>(true /*do_verification*/,
                                                                        init_method,
                                                                        false /*do_logs*/,
                                                                        false /*time_kernel*/,
                                                                        *p_cd_element_op,
                                                                        problem.M,
                                                                        problem.N,
                                                                        problem.K,
                                                                        a_strides,
                                                                        b_strides,
                                                                        c_strides,
                                                                        d_strides);
                EXPECT_TRUE(passed);
            }
        }
    }
};
// Fixture for the Scale test suite; adds nothing to TestContraction beyond a distinct name.
template <typename Tuple>
class TestContractionScale : public TestContraction<Tuple>
{
};
// Expands to the four A/B layout combinations (the C/D layout is always Row) for one
// data type / D tuple / compute type / element-wise op.
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op)  \
    std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>,   \
        std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
        std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
        std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>

// Scale is exercised in fp32 and fp64, each with an empty D tuple.
using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
                                          ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;

TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
// Runs the whole profiling sweep twice: first with scale 1, then with scale 0.5.
TYPED_TEST(TestContractionScale, scale)
{
    for(const float factor : {1.f, 0.5f})
    {
        this->p_cd_element_op = std::make_unique<Scale>(factor);
        this->Run();
    }
}
test/complex_contraction_scale/test_complex_contraction_scale_interface.cpp
0 → 100755
View file @
75cf3655
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
#include "ck/library/utility/device_memory.hpp"
// Element-wise operations used by the device ops under test.
using Pass  = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;

// Shorthand for a compile-time integer sequence.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Short names for the element types under test.
using F32 = float;
using F64 = double;
// Wraps one concrete DeviceContractionMultipleD_Xdl_CShuffle instantiation so the tests can
// query IsSupportedArgument() for different vector-access configurations.
//
// Template parameters select the source vector dimension used when loading A and B and the
// scalar-per-vector width used for the C/D/E transfer.
//
// NOTE(review): the instance pairs the Scale element-wise op with a one-element D tuple
// (ck::Tuple<F32>) and passes one D pointer; confirm that Scale is meant to be used with a D
// tensor here (the zero-D alternative is left commented out in the original).
template <ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t CDEBlockTransferScalarPerVector>
class ContractionInstanceWrapper
{
    public:
    static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    static constexpr ck::index_t NumDim = 2;

    // clang-format off
    using ContractionDeviceInstance = ck::tensor_operation::device::
        DeviceContractionMultipleD_Xdl_CShuffle<
            NumDim, NumDim, NumDim,             // NumDimM, NumDimN, NumDimK
            F32, F32, F32, F32,                 // A, B, Acc, CShuffle data types
            ck::Tuple<F32>, F32,                // Ds, E data types
            Pass, Pass, Scale,                  // A, B, CDE element-wise operations
            GemmSpec,                           // GEMM specialization
            1,                                  // NumGemmK prefetch stage
            256,                                // block size
            256, 128, 16,                       // M/N/K per block
            4, 4,                               // AK1, BK1
            32, 32,                             // M/N per XDL
            4, 2,                               // M/N XDL per wave
            S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, // A: thread cluster lengths, arrange order, src access order
            ABlockTransferSrcVectorDim, 4, 4, 1, // A: src vector dim, src scalar/vector, dst scalar/vector, extra M
            S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, // B: thread cluster lengths, arrange order, src access order
            BBlockTransferSrcVectorDim, 4, 4, 1, // B: src vector dim, src scalar/vector, dst scalar/vector, extra N
            1, 1,                               // CShuffle M/N XDL per wave per shuffle
            S<1, 16, 1, 16>,                    // C block transfer cluster lengths
            CDEBlockTransferScalarPerVector,    // C block transfer scalar per vector
            F32>;                               // compute data type
    // clang-format on

    // Returns whether the wrapped instance accepts a problem with the given lengths and strides.
    // Only descriptors are examined: every tensor pointer handed to MakeArgument is null.
    bool isSupported(std::vector<ck::index_t>& ADims,
                     std::vector<ck::index_t>& BDims,
                     std::vector<ck::index_t>& DDims,
                     std::vector<ck::index_t>& EDims,
                     std::vector<ck::index_t>& AStrides,
                     std::vector<ck::index_t>& BStrides,
                     std::vector<ck::index_t>& DStrides,
                     std::vector<ck::index_t>& EStrides) const
    {
        auto device_op = ContractionDeviceInstance{};

        auto device_arg =
            device_op.MakeArgument(nullptr,
                                   nullptr,
                                   std::array<const void*, 1>{nullptr},
                                   nullptr,
                                   ADims,
                                   AStrides,
                                   BDims,
                                   BStrides,
                                   std::array<std::vector<ck::index_t>, 1>{DDims},
                                   std::array<std::vector<ck::index_t>, 1>{DStrides},
                                   EDims,
                                   EStrides,
                                   Pass{},
                                   Pass{},
                                   Scale{1.f});

        return device_op.IsSupportedArgument(device_arg);
    }
};
// Queries the instance factory for the DeviceContractionMultipleD interface described by the
// template parameters and reports whether any registered instance accepts a given problem.
template <typename DataTypeA,
          typename DataTypeB,
          typename DataTypeC,
          typename DataTypeD,
          ck::index_t NumDim>
class ContractionDeviceOpWrapper
{
    protected:
    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
                                                                              NumDim,
                                                                              NumDim,
                                                                              DataTypeA,
                                                                              DataTypeB,
                                                                              ck::Tuple<DataTypeC>,
                                                                              DataTypeD,
                                                                              Pass,
                                                                              Pass,
                                                                              Scale>;

    public:
    // Returns true when at least one factory instance supports a problem in which every tensor
    // uses the same `Dims` / `Strides`. All tensor pointers are null; only descriptors are built.
    bool IsSupportedInstance(std::vector<ck::index_t>& Dims,
                             std::vector<ck::index_t>& Strides) const
    {
        const auto instances = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        bool any_supported = false;

        for(auto& instance : instances)
        {
            auto arg_ptr =
                instance->MakeArgumentPointer(nullptr,
                                              nullptr,
                                              std::array<const void*, 1>{nullptr},
                                              nullptr,
                                              Dims,
                                              Strides,
                                              Dims,
                                              Strides,
                                              std::array<std::vector<ck::index_t>, 1>{Dims},
                                              std::array<std::vector<ck::index_t>, 1>{Strides},
                                              Dims,
                                              Strides,
                                              Pass{},
                                              Pass{},
                                              Scale{1.f});

            // Once a supporting instance has been found the remaining ones are not queried.
            if(!any_supported)
            {
                any_supported = instance->IsSupportedArgument(arg_ptr.get());
            }
        }

        return any_supported;
    }
};
// Only the wrapper with NumDim == 2 is expected to find a supporting instance.
TEST(TestContractionInterface, IncorrectNumDims)
{
    // Index 0/1/2 holds the lengths and strides for the 1-, 2- and 3-dims-per-group cases.
    std::vector<std::vector<ck::index_t>> lengths_per_case = {
        {4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
    std::vector<std::vector<ck::index_t>> strides_per_case = {
        {1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> one_dim_wrapper;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> two_dim_wrapper;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> three_dim_wrapper;

    EXPECT_FALSE(one_dim_wrapper.IsSupportedInstance(lengths_per_case[0], strides_per_case[0]));
    EXPECT_TRUE(two_dim_wrapper.IsSupportedInstance(lengths_per_case[1], strides_per_case[1]));
    EXPECT_FALSE(three_dim_wrapper.IsSupportedInstance(lengths_per_case[2], strides_per_case[2]));
}
// Mixed fp32/fp64 type combinations are expected to have no supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> lengths = {4, 4, 4, 4};
    std::vector<ck::index_t> strides = {64, 16, 4, 1};

    ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> f32_inputs_f64_outputs;
    ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> f64_inputs_f32_outputs;

    EXPECT_FALSE(f32_inputs_f64_outputs.IsSupportedInstance(lengths, strides));
    EXPECT_FALSE(f64_inputs_f32_outputs.IsSupportedInstance(lengths, strides));
}
// Checks which A and B stride patterns the instance accepts for each source vector dimension.
TEST(TestContractionInterface, ABMemoryAccess)
{
    std::vector<ck::index_t> lengths       = {4, 4, 4, 4};
    std::vector<ck::index_t> packed        = {64, 16, 4, 1};
    std::vector<ck::index_t> unit_stride_m = {4, 1, 64, 16}; // unit stride at index 1
    std::vector<ck::index_t> unit_stride_k = {64, 16, 4, 1}; // unit stride at index 3
    std::vector<ck::index_t> no_unit       = {4, 4, 4, 4};   // no unit stride at all

    // Memory access to A
    ContractionInstanceWrapper<1, 2, 4> a_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> a_vector_dim_2;

    EXPECT_FALSE(a_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, no_unit, packed, packed, packed));
    EXPECT_FALSE(a_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, no_unit, packed, packed, packed));
    EXPECT_TRUE(a_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, unit_stride_m, packed, packed, packed));
    EXPECT_TRUE(a_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, unit_stride_k, packed, packed, packed));

    // Memory access to B
    ContractionInstanceWrapper<2, 1, 4> b_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> b_vector_dim_2;

    EXPECT_FALSE(b_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, packed, no_unit, packed, packed));
    EXPECT_FALSE(b_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, packed, no_unit, packed, packed));
    EXPECT_TRUE(b_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, packed, unit_stride_m, packed, packed));
    EXPECT_TRUE(b_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, packed, unit_stride_k, packed, packed));
}
// Checks that D and E are rejected when the last two strides are swapped (unit stride no longer
// last) and accepted with the packed strides.
TEST(TestContractionSupportedArgs, DEMemoryAccess)
{
    std::vector<ck::index_t> lengths = {4, 4, 4, 4};
    std::vector<ck::index_t> packed  = {64, 16, 4, 1};
    std::vector<ck::index_t> swapped = {64, 16, 1, 4};

    ContractionInstanceWrapper<2, 2, 4> instance;

    // Memory access to D
    EXPECT_FALSE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, swapped, packed));
    EXPECT_TRUE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, packed, packed));

    // Memory access to E
    EXPECT_FALSE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, packed, swapped));
    EXPECT_TRUE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, packed, packed));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment