Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6c2b74de
Commit
6c2b74de
authored
Jul 05, 2022
by
Chao Liu
Browse files
change assumed virtual layout of contraction; add client example
parent
f6c42793
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
655 additions
and
224 deletions
+655
-224
client_example/05_contraction_bilinear/CMakeLists.txt
client_example/05_contraction_bilinear/CMakeLists.txt
+2
-0
client_example/05_contraction_bilinear/contraction_bilinear.cpp
..._example/05_contraction_bilinear/contraction_bilinear.cpp
+241
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+1
-0
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
+14
-14
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
+12
-12
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
+15
-15
example/24_contraction_bilinear/contraction_bilinear_xdl_fp32.cpp
...24_contraction_bilinear/contraction_bilinear_xdl_fp32.cpp
+101
-62
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
...or_operation/gpu/device/device_contraction_multiple_d.hpp
+12
-12
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
...gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
+119
-101
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+2
-0
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
...ry/tensor_operation_instance/gpu/contraction_bilinear.hpp
+128
-0
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
..._m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
..._m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
..._m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
..._m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
+2
-2
No files found.
client_example/05_contraction_bilinear/CMakeLists.txt
0 → 100644
View file @
6c2b74de
add_executable
(
client_contraction_bilinear contraction_bilinear.cpp
)
target_link_libraries
(
client_contraction_bilinear PRIVATE composable_kernel::device_operations
)
client_example/05_contraction_bilinear/contraction_bilinear.cpp
0 → 100644
View file @
6c2b74de
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Bilinear
;
using
ADataType
=
F32
;
using
BDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
DDataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
using
EDataType
=
F32
;
static
constexpr
ck
::
index_t
NumDimM
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
2
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
// A[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// B[N0, N1, K0, K1]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
// D[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
// E[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
float
alpha
=
1.
f
;
float
beta
=
1.
f
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
25
)
{
const
ck
::
index_t
M0
=
std
::
stoi
(
argv
[
1
]);
const
ck
::
index_t
M1
=
std
::
stoi
(
argv
[
2
]);
const
ck
::
index_t
N0
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
N1
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
K0
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
K1
=
std
::
stoi
(
argv
[
6
]);
a_ms_ks_lengths
=
{
M0
,
M1
,
K0
,
K1
};
a_ms_ks_strides
=
{
std
::
stoi
(
argv
[
7
]),
std
::
stoi
(
argv
[
8
]),
std
::
stoi
(
argv
[
9
]),
std
::
stoi
(
argv
[
10
])};
b_ns_ks_lengths
=
{
N0
,
N1
,
K0
,
K1
};
b_ns_ks_strides
=
{
std
::
stoi
(
argv
[
11
]),
std
::
stoi
(
argv
[
12
]),
std
::
stoi
(
argv
[
13
]),
std
::
stoi
(
argv
[
14
])};
d_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
d_ms_ns_strides
=
{
std
::
stoi
(
argv
[
15
]),
std
::
stoi
(
argv
[
16
]),
std
::
stoi
(
argv
[
17
]),
std
::
stoi
(
argv
[
18
])};
e_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
e_ms_ns_strides
=
{
std
::
stoi
(
argv
[
19
]),
std
::
stoi
(
argv
[
20
]),
std
::
stoi
(
argv
[
21
]),
std
::
stoi
(
argv
[
22
])};
alpha
=
std
::
stof
(
argv
[
23
]);
beta
=
std
::
stof
(
argv
[
24
]);
}
else
{
printf
(
"arg1 to 6: M0, M1, N0, N1, K0, K1
\n
"
);
printf
(
"arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1
\n
"
);
printf
(
"arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1
\n
"
);
printf
(
"arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1
\n
"
);
printf
(
"arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1
\n
"
);
printf
(
"arg23 to 24: alpha, beta
\n
"
);
exit
(
0
);
}
auto
f_tensor_space_size
=
[](
auto
lengths
,
auto
strides
)
{
std
::
size_t
space_size
=
1
;
for
(
std
::
size_t
i
=
0
;
i
<
lengths
.
size
();
++
i
)
{
space_size
+=
(
lengths
[
i
]
-
1
)
*
strides
[
i
];
}
return
space_size
;
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
f_tensor_space_size
(
a_ms_ks_lengths
,
a_ms_ks_strides
));
SimpleDeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
f_tensor_space_size
(
b_ns_ks_lengths
,
b_ns_ks_strides
));
SimpleDeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
f_tensor_space_size
(
d_ms_ns_lengths
,
d_ms_ns_strides
));
SimpleDeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
f_tensor_space_size
(
e_ms_ns_lengths
,
e_ms_ns_strides
));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
ck
::
index_t
M
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
N
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
e_ms_ns_lengths
.
begin
()
+
NumDimM
+
NumDimN
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
K
=
std
::
accumulate
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
a_ms_ks_lengths
.
begin
()
+
NumDimM
+
NumDimK
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
DDataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
0
;
}
client_example/CMakeLists.txt
View file @
6c2b74de
...
...
@@ -9,3 +9,4 @@ message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory
(
01_gemm
)
add_subdirectory
(
02_gemm_add_add_fastgelu
)
add_subdirectory
(
03_gemm_layernorm
)
add_subdirectory
(
05_contraction_bilinear
)
example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp
View file @
6c2b74de
...
...
@@ -213,15 +213,15 @@ int main(int argc, char* argv[])
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
a_
m_k_
device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_
k_n_
device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_
m_n_
device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_
m_n_
device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_
m_k_
device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_
k_n_
device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d_
m_n_
device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
e_
m_n_
device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
...
...
@@ -231,10 +231,10 @@ int main(int argc, char* argv[])
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a_
m_k_
device_buf
.
GetDeviceBuffer
(),
b_
k_n_
device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_
m_n_
device_buf
.
GetDeviceBuffer
()},
e_
m_n_
device_buf
.
GetDeviceBuffer
(),
device_op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
...
...
@@ -266,7 +266,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
e_
m_n_
device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
{
...
...
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
}
}
e_
m_n_
device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
.
mData
,
e_m_n_host_result
.
mData
)
?
0
:
1
;
}
...
...
example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
View file @
6c2b74de
...
...
@@ -191,14 +191,14 @@ int main(int argc, char* argv[])
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
0.0
,
1.0
});
}
DeviceMem
a_
m_k_
device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_
k_n_
device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_
m_n_
device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_
m_n_
device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_
m_k_
device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_
k_n_
device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d_
m_n_
device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
...
...
@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a_
m_k_
device_buf
.
GetDeviceBuffer
(),
b_
k_n_
device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_
m_n_
device_buf
.
GetDeviceBuffer
()},
e_
m_n_
device_buf
.
GetDeviceBuffer
(),
device_op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
...
...
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
e_
m_n_
device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
Tensor
<
AccDataType
>
c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
...
...
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
View file @
6c2b74de
...
...
@@ -156,16 +156,16 @@ int main(int argc, char* argv[])
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
DeviceMem
a_
m_k_
device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_
k_n_
device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_
m_n_
device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_
m_n_
device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_
m_n_
device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_
m_k_
device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_
k_n_
device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_
m_n_
device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_
m_n_
device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
...
...
@@ -175,11 +175,11 @@ int main(int argc, char* argv[])
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a_
m_k_
device_buf
.
GetDeviceBuffer
(),
b_
k_n_
device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
d0_
m_n_
device_buf
.
GetDeviceBuffer
(),
d1_
m_n_
device_buf
.
GetDeviceBuffer
()},
e_
m_n_
device_buf
.
GetDeviceBuffer
(),
device_op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
...
...
@@ -239,7 +239,7 @@ int main(int argc, char* argv[])
}
}
e_
m_n_
device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
.
mData
,
e_m_n_host_result
.
mData
)
?
0
:
1
;
}
...
...
example/24_contraction_bilinear/contraction_bilinear_xdl_fp32.cpp
View file @
6c2b74de
...
...
@@ -69,13 +69,13 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_
k
s_
n
s
,
const
Tensor
<
BDataType
>&
b_
n
s_
k
s
,
Tensor
<
EDataType
>&
e_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
a_ms_ks_
{
a_ms_ks
},
b_
k
s_
n
s_
{
b_
k
s_
n
s
},
b_
n
s_
k
s_
{
b_
n
s_
k
s
},
e_ms_ns_
{
e_ms_ns
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
...
...
@@ -84,7 +84,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
const
Tensor
<
ADataType
>&
a_ms_ks_
;
const
Tensor
<
BDataType
>&
b_
k
s_
n
s_
;
const
Tensor
<
BDataType
>&
b_
n
s_
k
s_
;
Tensor
<
EDataType
>&
e_ms_ns_
;
AElementwiseOperation
a_element_op_
;
...
...
@@ -115,7 +115,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
AccDataType
>
(
arg
.
b_
k
s_
n
s_
(
k
0
,
k
1
,
n
0
,
n
1
)));
v_b
,
static_cast
<
const
AccDataType
>
(
arg
.
b_
n
s_
k
s_
(
n
0
,
n
1
,
k
0
,
k
1
)));
v_acc
+=
v_a
*
v_b
;
}
...
...
@@ -157,13 +157,13 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_
k
s_
n
s
,
const
Tensor
<
BDataType
>&
b_
n
s_
k
s
,
Tensor
<
EDataType
>&
e_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_ms_ks
,
b_
k
s_
n
s
,
e_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
return
Argument
{
a_ms_ks
,
b_
n
s_
k
s
,
e_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -192,42 +192,86 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
bool
time_kernel
=
false
;
if
(
argc
==
4
)
// A[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// B[N0, N1, K0, K1]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
// D[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
// E[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
float
alpha
=
1.
f
;
float
beta
=
1.
f
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
28
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
M0
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
M1
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
N0
=
std
::
stoi
(
argv
[
6
]);
const
ck
::
index_t
N1
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
K0
=
std
::
stoi
(
argv
[
8
]);
const
ck
::
index_t
K1
=
std
::
stoi
(
argv
[
9
]);
a_ms_ks_lengths
=
{
M0
,
M1
,
K0
,
K1
};
a_ms_ks_strides
=
{
std
::
stoi
(
argv
[
10
]),
std
::
stoi
(
argv
[
11
]),
std
::
stoi
(
argv
[
12
]),
std
::
stoi
(
argv
[
13
])};
b_ns_ks_lengths
=
{
N0
,
N1
,
K0
,
K1
};
b_ns_ks_strides
=
{
std
::
stoi
(
argv
[
14
]),
std
::
stoi
(
argv
[
15
]),
std
::
stoi
(
argv
[
16
]),
std
::
stoi
(
argv
[
17
])};
d_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
d_ms_ns_strides
=
{
std
::
stoi
(
argv
[
18
]),
std
::
stoi
(
argv
[
19
]),
std
::
stoi
(
argv
[
20
]),
std
::
stoi
(
argv
[
21
])};
e_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
e_ms_ns_strides
=
{
std
::
stoi
(
argv
[
22
]),
std
::
stoi
(
argv
[
23
]),
std
::
stoi
(
argv
[
24
]),
std
::
stoi
(
argv
[
25
])};
alpha
=
std
::
stof
(
argv
[
26
]);
beta
=
std
::
stof
(
argv
[
27
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 7: M0, M1, N0, N1, K0, K1
\n
"
);
printf
(
"arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1
\n
"
);
printf
(
"arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1
\n
"
);
printf
(
"arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1
\n
"
);
printf
(
"arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1
\n
"
);
printf
(
"arg26 to 27: alpha, beta
\n
"
);
exit
(
0
);
}
const
float
alpha
=
1
;
const
float
beta
=
1
;
// A[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// B[K0, K1, N0, N1]
std
::
vector
<
ck
::
index_t
>
b_ks_ns_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ks_ns_strides
{
128
,
1
,
524288
,
4096
};
// D[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
// E[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
Tensor
<
ADataType
>
a_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_strides
.
begin
(),
a_ms_ks_strides
.
end
()));
Tensor
<
BDataType
>
b_
k
s_
n
s
(
std
::
vector
<
std
::
size_t
>
(
b_
k
s_
n
s_lengths
.
begin
(),
b_
k
s_
n
s_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
b_
k
s_
n
s_strides
.
begin
(),
b_
k
s_
n
s_strides
.
end
()));
Tensor
<
BDataType
>
b_
n
s_
k
s
(
std
::
vector
<
std
::
size_t
>
(
b_
n
s_
k
s_lengths
.
begin
(),
b_
n
s_
k
s_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
b_
n
s_
k
s_strides
.
begin
(),
b_
n
s_
k
s_strides
.
end
()));
Tensor
<
EDataType
>
d_ms_ns
(
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_lengths
.
begin
(),
d_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_strides
.
begin
(),
d_ms_ns_strides
.
end
()));
...
...
@@ -239,7 +283,7 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
std
::
cout
<<
"a_ms_ks: "
<<
a_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_
k
s_
n
s: "
<<
b_
k
s_
n
s
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_
n
s_
k
s: "
<<
b_
n
s_
k
s
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns: "
<<
d_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns: "
<<
e_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
...
...
@@ -248,55 +292,50 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
1
:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_
k
s_
n
s
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_
n
s_
k
s
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
default
:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_
k
s_
n
s
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_
n
s_
k
s
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b_ks_ns
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
}
DeviceMem
a_
ms_ks_
device_buf
(
sizeof
(
ADataType
)
*
a_ms_ks
.
mDesc
.
GetElementSpace
());
DeviceMem
b_
ks_ns_
device_buf
(
sizeof
(
BDataType
)
*
b_
k
s_
n
s
.
mDesc
.
GetElementSpace
());
DeviceMem
d_
ms_ns_
device_buf
(
sizeof
(
DDataType
)
*
d_ms_ns
.
mDesc
.
GetElementSpace
());
DeviceMem
e_
ms_ns_
device_buf
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_ms_ks
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_
n
s_
k
s
.
mDesc
.
GetElementSpace
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_ms_ns
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result
.
mDesc
.
GetElementSpace
());
a_
ms_ks_
device_buf
.
ToDevice
(
a_ms_ks
.
mData
.
data
());
b_
ks_ns_
device_buf
.
ToDevice
(
b_
k
s_
n
s
.
mData
.
data
());
d_
ms_ns_
device_buf
.
ToDevice
(
d_ms_ns
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_ms_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_
n
s_
k
s
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_ms_ns
.
mData
.
data
());
// set zero
e_
ms_ns_
device_buf
.
SetZero
();
e_device_buf
.
SetZero
();
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
// device operation
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument
=
op
.
MakeArgument
(
a_ms_ks_device_buf
.
GetDeviceBuffer
(),
b_ks_ns_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_ms_ns_device_buf
.
GetDeviceBuffer
()},
e_ms_ns_device_buf
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ks_ns_lengths
,
b_ks_ns_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument
=
op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument
))
{
...
...
@@ -333,7 +372,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_
ms_ns_
device_buf
.
FromDevice
(
e_ms_ns_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_ms_ns_device_result
.
mData
.
data
());
if
(
do_verification
)
{
...
...
@@ -356,7 +395,7 @@ int main(int argc, char* argv[])
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_
k
s_
n
s
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
,
PassThrough
{});
a_ms_ks
,
b_
n
s_
k
s
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
...
...
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
View file @
6c2b74de
...
...
@@ -20,10 +20,10 @@ namespace device {
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[
K
0,
K
1,
K
2, ...,
N
0,
N
1,
N2
...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
// A[M0, M1, M2, ..., K0, K1, K2
,
...]
// B[
N
0,
N
1,
N
2, ...,
K
0,
K
1,
K2,
...]
// D[M0, M1, M2, ..., N0, N1, N2
,
...]
// E[M0, M1, M2, ..., N0, N1, N2
,
...]
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
...
...
@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_lengths
,
std
::
vector
<
index_t
>
b_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_strides
,
std
::
vector
<
index_t
>
e_lengths
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
a_
ms_ks_
lengths
,
std
::
vector
<
index_t
>
a_
ms_ks_
strides
,
std
::
vector
<
index_t
>
b_
ns_ks_
lengths
,
std
::
vector
<
index_t
>
b_
ns_ks_
strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
strides
,
std
::
vector
<
index_t
>
e_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
e_
ms_ns_
strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
...
...
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
6c2b74de
...
...
@@ -97,10 +97,10 @@ namespace device {
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2...]
// B[
K
0,
K
1,
K
2, ...,
N
0,
N
1,
N2
...]
// D[M0, M1, M2, ..., N0, N1, N2...]
// E[M0, M1, M2, ..., N0, N1, N2...]
// A[M0, M1, M2, ..., K0, K1, K2
,
...]
// B[
N
0,
N
1,
N
2, ...,
K
0,
K
1,
K2,
...]
// D[M0, M1, M2, ..., N0, N1, N2
,
...]
// E[M0, M1, M2, ..., N0, N1, N2
,
...]
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
...
...
@@ -164,19 +164,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
//
a
ssume A[M0, M1, M2, ..., K0, K1, K2...]
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_strides_vec
)
//
A
ssume
:
A[M0, M1, M2, ..., K0, K1, K2
,
...]
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_
ms_ks_
lengths_vec
,
const
std
::
vector
<
index_t
>&
a_
ms_ks_
strides_vec
)
{
assert
(
a_lengths_vec
.
size
()
==
NumDimM
+
NumDimK
&&
a_strides_vec
.
size
()
==
NumDimM
+
NumDimK
);
assert
(
a_
ms_ks_
lengths_vec
.
size
()
==
NumDimM
+
NumDimK
&&
a_
ms_ks_
strides_vec
.
size
()
==
NumDimM
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
Num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
Num
);
};
const
auto
a_lengths
=
to_tuple
(
a_lengths_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_strides
=
to_tuple
(
a_strides_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_
ms_ns_
lengths
=
to_tuple
(
a_
ms_ks_
lengths_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
const
auto
a_
ms_ks_
strides
=
to_tuple
(
a_
ms_ks_
strides_vec
,
Number
<
NumDimM
+
NumDimK
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
...
...
@@ -186,13 +186,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimK
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
a_lengths
,
mDimIds
);
const
auto
mLengths
=
get_container_subset
(
a_
ms_ns_
lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
a_lengths
,
kDimIds
);
const
auto
kLengths
=
get_container_subset
(
a_
ms_ns_
lengths
,
kDimIds
);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_lengths
,
a_strides
);
const
auto
a_grid_desc_ms_ks
=
make_naive_tensor_descriptor
(
a_ms_ns_lengths
,
a_ms_ks_strides
);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const
auto
a_grid_desc_mraw_kraw
=
transform_tensor_descriptor
(
...
...
@@ -292,42 +293,43 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
//
a
ssume B[
K
0,
K
1,
K
2, ...,
N
0,
N
1,
N2
...]
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_strides_vec
)
//
A
ssume
:
B[
N
0,
N
1,
N
2, ...,
K
0,
K
1,
K2,
...]
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_
ns_ks_
lengths_vec
,
const
std
::
vector
<
index_t
>&
b_
ns_ks_
strides_vec
)
{
assert
(
b_lengths_vec
.
size
()
==
NumDimN
+
NumDimK
&&
b_strides_vec
.
size
()
==
NumDimN
+
NumDimK
);
assert
(
b_
ns_ks_
lengths_vec
.
size
()
==
NumDimN
+
NumDimK
&&
b_
ns_ks_
strides_vec
.
size
()
==
NumDimN
+
NumDimK
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
Num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
Num
);
};
const
auto
b_lengths
=
to_tuple
(
b_lengths_vec
,
Number
<
NumDimN
+
NumDimK
>
{});
const
auto
b_strides
=
to_tuple
(
b_strides_vec
,
Number
<
NumDimN
+
NumDimK
>
{});
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimK
,
1
>::
type
{};
const
auto
b_ns_ks_lengths
=
to_tuple
(
b_ns_ks_lengths_vec
,
Number
<
NumDimN
+
NumDimK
>
{});
const
auto
b_ns_ks_strides
=
to_tuple
(
b_ns_ks_strides_vec
,
Number
<
NumDimN
+
NumDimK
>
{});
// dimension Ids for N0, N1, ...
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDimK
,
NumDimK
+
NumDimN
,
1
>::
type
{};
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimN
,
1
>::
type
{};
// dimension Ids for K0, K1, ...
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDimN
,
NumDimN
+
NumDimK
,
1
>::
type
{};
// lengths for K0, K1, ...
const
auto
kLengths
=
get_container_subset
(
b_lengths
,
kDimIds
);
const
auto
kLengths
=
get_container_subset
(
b_
ns_ks_
lengths
,
kDimIds
);
// lengths for N0, N1, ...
const
auto
nLengths
=
get_container_subset
(
b_lengths
,
nDimIds
);
const
auto
nLengths
=
get_container_subset
(
b_
ns_ks_
lengths
,
nDimIds
);
// naive tensor B[K0, K1, K2..., N0, N1, N2, ...]
const
auto
b_grid_desc_ks_ns
=
make_naive_tensor_descriptor
(
b_lengths
,
b_strides
);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const
auto
b_grid_desc_ns_ks
=
make_naive_tensor_descriptor
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
// transformed tensor B[
K
Raw =
K
0 *
K
1 *
K
2 * ...
,
N
Raw =
N
0 *
N
1 *
N
2 * ...
]
// transformed tensor B[
N
Raw =
N
0 *
N
1 *
N
2 * ...,
K
Raw =
K
0 *
K
1 *
K
2 * ...]
const
auto
b_grid_desc_nraw_kraw
=
transform_tensor_descriptor
(
b_grid_desc_
k
s_
n
s
,
make_tuple
(
make_merge_transform
(
k
Lengths
),
make_merge_transform
(
n
Lengths
)),
make_tuple
(
k
DimIds
,
n
DimIds
),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
b_grid_desc_
n
s_
k
s
,
make_tuple
(
make_merge_transform
(
n
Lengths
),
make_merge_transform
(
k
Lengths
)),
make_tuple
(
n
DimIds
,
k
DimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
NRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I0
);
const
auto
KRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I1
);
...
...
@@ -421,18 +423,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_lengths_vec
,
const
std
::
vector
<
index_t
>&
e_strides_vec
)
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
e_
ms_ns_
lengths_vec
,
const
std
::
vector
<
index_t
>&
e_
ms_ns_
strides_vec
)
{
assert
(
e_lengths_vec
.
size
()
==
NumDimM
+
NumDimN
&&
e_strides_vec
.
size
()
==
NumDimM
+
NumDimN
);
assert
(
e_
ms_ns_
lengths_vec
.
size
()
==
NumDimM
+
NumDimN
&&
e_
ms_ns_
strides_vec
.
size
()
==
NumDimM
+
NumDimN
);
const
auto
to_tuple
=
[
&
](
auto
&
vec
,
auto
Num
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
vec
[
i
];
},
Num
);
};
const
auto
e_lengths
=
to_tuple
(
e_lengths_vec
,
Number
<
NumDimM
+
NumDimN
>
{});
const
auto
e_strides
=
to_tuple
(
e_strides_vec
,
Number
<
NumDimM
+
NumDimN
>
{});
const
auto
e_
ms_ns_
lengths
=
to_tuple
(
e_
ms_ns_
lengths_vec
,
Number
<
NumDimM
+
NumDimN
>
{});
const
auto
e_
ms_ns_
strides
=
to_tuple
(
e_
ms_ns_
strides_vec
,
Number
<
NumDimM
+
NumDimN
>
{});
// dimension Ids for M0, M1, ...
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDimM
,
1
>::
type
{};
...
...
@@ -442,13 +444,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename
arithmetic_sequence_gen
<
NumDimM
,
NumDimM
+
NumDimN
,
1
>::
type
{};
// lengths for M0, M1, ...
const
auto
mLengths
=
get_container_subset
(
e_lengths
,
mDimIds
);
const
auto
mLengths
=
get_container_subset
(
e_
ms_ns_
lengths
,
mDimIds
);
// lengths for K0, K1, ...
const
auto
nLengths
=
get_container_subset
(
e_lengths
,
nDimIds
);
const
auto
nLengths
=
get_container_subset
(
e_
ms_ns_
lengths
,
nDimIds
);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_lengths
,
e_strides
);
const
auto
e_grid_desc_ms_ns
=
make_naive_tensor_descriptor
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const
auto
e_grid_desc_mraw_nraw
=
transform_tensor_descriptor
(
...
...
@@ -564,14 +567,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
std
::
vector
<
index_t
>
a_lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_lengths
,
std
::
vector
<
index_t
>
b_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_strides
,
std
::
vector
<
index_t
>
e_lengths
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
a_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
a_
ms_ks_
strides
,
std
::
vector
<
index_t
>
b_
ns_ks_
lengths
,
std
::
vector
<
index_t
>
b_
ns_ks_
strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
strides
,
std
::
vector
<
index_t
>
e_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
e_
ms_ns_
strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
...
...
@@ -579,10 +582,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_lengths
,
a_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_lengths
,
b_strides
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_ms_ns_lengths
,
a_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_ns_ks_lengths
,
b_ns_ks_strides
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_lengths
,
e_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_
ms_ns_
lengths
,
e_
ms_ns_
strides
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
...
...
@@ -616,7 +621,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_lengths
[
i
],
ds_strides
[
i
]);
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_
ms_ns_
lengths
[
i
],
ds_
ms_ns_
strides
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
@@ -624,24 +629,33 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
});
}
// for sanity check of vector load/store
a_mz_length_
=
a_lengths
[
NumDimM
-
1
];
a_mz_stride_
=
a_strides
[
NumDimM
-
1
];
// for sanity check of vector memory access
a_mz_length_
=
a_ms_ns_lengths
[
NumDimM
-
1
];
a_mz_stride_
=
a_ms_ks_strides
[
NumDimM
-
1
];
a_kz_length_
=
a_ms_ns_lengths
[
NumDimM
+
NumDimK
-
1
];
a_kz_stride_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
];
a_k
z_length_
=
a
_lengths
[
NumDim
M
+
NumDimK
-
1
];
a_k
z_stride_
=
a
_strides
[
NumDim
M
+
NumDimK
-
1
];
b_n
z_length_
=
b_ns_ks
_lengths
[
NumDim
N
-
1
];
b_n
z_stride_
=
b_ns_ks
_strides
[
NumDim
N
-
1
];
b_
n
z_length_
=
b_lengths
[
NumDimK
-
1
];
b_
n
z_stride_
=
b_strides
[
NumDimK
-
1
];
b_
k
z_length_
=
b_
ns_ks_
lengths
[
NumDimN
+
NumDimK
-
1
];
b_
k
z_stride_
=
b_
ns_ks_
strides
[
NumDimN
+
NumDimK
-
1
];
b_kz_length_
=
b_lengths
[
NumDimK
+
NumDimN
-
1
];
b_kz_stride_
=
b_strides
[
NumDimK
+
NumDimN
-
1
];
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_mz_length_
[
i
]
=
ds_ms_ns_lengths
[
i
][
NumDimM
-
1
];
ds_mz_stride_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
-
1
];
ds_nz_length_
[
i
]
=
ds_ms_ns_lengths
[
i
][
NumDimM
+
NumDimN
-
1
];
ds_nz_stride_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
];
}
e_mz_length_
=
b
_lengths
[
NumDimM
-
1
];
e_mz_stride_
=
b
_strides
[
NumDimM
-
1
];
e_mz_length_
=
e_ms_ns
_lengths
[
NumDimM
-
1
];
e_mz_stride_
=
e_ms_ns
_strides
[
NumDimM
-
1
];
e_nz_length_
=
b
_lengths
[
NumDimM
+
NumDimN
-
1
];
e_nz_stride_
=
b
_strides
[
NumDimM
+
NumDimN
-
1
];
e_nz_length_
=
e_ms_ns
_lengths
[
NumDimM
+
NumDimN
-
1
];
e_nz_stride_
=
e_ms_ns
_strides
[
NumDimM
+
NumDimN
-
1
];
}
// private:
...
...
@@ -682,6 +696,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
index_t
b_nz_stride_
;
index_t
b_kz_length_
;
index_t
b_kz_stride_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_mz_length_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_mz_stride_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_length_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_nz_stride_
;
index_t
e_mz_length_
;
index_t
e_mz_stride_
;
index_t
e_nz_length_
;
...
...
@@ -823,14 +841,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_lengths
,
std
::
vector
<
index_t
>
b_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_strides
,
std
::
vector
<
index_t
>
e_lengths
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
a_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
a_
ms_ks_
strides
,
std
::
vector
<
index_t
>
b_
ns_ks_
lengths
,
std
::
vector
<
index_t
>
b_
ns_ks_
strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
strides
,
std
::
vector
<
index_t
>
e_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
e_
ms_ns_
strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
...
...
@@ -839,14 +857,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b
,
p_ds
,
p_e
,
a_lengths
,
a_strides
,
b_lengths
,
b_strides
,
ds_lengths
,
ds_strides
,
e_lengths
,
e_strides
,
a_
ms_ns_
lengths
,
a_
ms_ks_
strides
,
b_
ns_ks_
lengths
,
b_
ns_ks_
strides
,
ds_
ms_ns_
lengths
,
ds_
ms_ns_
strides
,
e_
ms_ns_
lengths
,
e_
ms_ns_
strides
,
a_element_op
,
b_element_op
,
cde_element_op
};
...
...
@@ -860,14 +878,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_lengths
,
std
::
vector
<
index_t
>
b_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_strides
,
std
::
vector
<
index_t
>
e_lengths
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
a_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
a_
ms_ks_
strides
,
std
::
vector
<
index_t
>
b_
ns_ks_
lengths
,
std
::
vector
<
index_t
>
b_
ns_ks_
strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_
ms_ns_
strides
,
std
::
vector
<
index_t
>
e_
ms_ns_
lengths
,
std
::
vector
<
index_t
>
e_
ms_ns_
strides
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
...
...
@@ -876,14 +894,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b
,
p_ds
,
p_e
,
a_lengths
,
a_strides
,
b_lengths
,
b_strides
,
ds_lengths
,
ds_strides
,
e_lengths
,
e_strides
,
a_
ms_ns_
lengths
,
a_
ms_ks_
strides
,
b_
ns_ks_
lengths
,
b_
ns_ks_
strides
,
ds_
ms_ns_
lengths
,
ds_
ms_ns_
strides
,
e_
ms_ns_
lengths
,
e_
ms_ns_
strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
6c2b74de
...
...
@@ -19,6 +19,8 @@ using BF16 = ck::bhalf_t;
using
F16_TUPLE
=
ck
::
Tuple
<
F16
>
;
using
F16_F16_TUPLE
=
ck
::
Tuple
<
F16
,
F16
>
;
using
F32_TUPLE
=
ck
::
Tuple
<
F32
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
0 → 100644
View file @
6c2b74de
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_TUPLE
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_TUPLE
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_TUPLE
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_TUPLE
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
// Contraction + Bilinear
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
DDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>>
{
using
DeviceOp
=
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
DDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
View file @
6c2b74de
...
...
@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
//
Compilation parameters for
A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
//
n0, n1]
k/k/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
n0, n1]
// k/k/n/n are the fast changing dimension for A/B/D/E
using
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
...
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
View file @
6c2b74de
...
...
@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
//
Compilation parameters for
A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
//
n0, n1]
k/n/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
n0, n1]
// k/n/n/n are the fast changing dimension for A/B/D/E
using
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
...
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
View file @
6c2b74de
...
...
@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
//
Compilation parameters for
A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
//
n0, n1]
m/k/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
n0, n1]
// m/k/n/n are the fast changing dimension for A/B/D/E
using
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
...
library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
View file @
6c2b74de
...
...
@@ -32,8 +32,8 @@ using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
//
Compilation parameters for
A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
//
n0, n1]
m/n/n/n are the fast changing dimension for A/B/D/E
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1,
n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment