Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f60ad8b9
Commit
f60ad8b9
authored
Sep 21, 2023
by
Jing Zhang
Browse files
add examples of multiA and broadcast
parent
0512580c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
115 additions
and
170 deletions
+115
-170
example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp
..._contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp
+114
-167
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+1
-3
No files found.
example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp
View file @
f60ad8b9
...
@@ -15,8 +15,9 @@
...
@@ -15,8 +15,9 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_
gemm
.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_
contraction
.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/numeric.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -29,7 +30,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -29,7 +30,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
A0DataType
=
F16
;
using
A1DataType
=
F32
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
...
@@ -40,44 +42,6 @@ static constexpr ck::index_t NumDimM = 2;
...
@@ -40,44 +42,6 @@ static constexpr ck::index_t NumDimM = 2;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
2
;
//struct AddScale
//{
//static constexpr auto I0 = ck::Number<0>{};
//static constexpr auto I1 = ck::Number<1>{};
//static constexpr auto I2 = ck::Number<2>{};
//static constexpr auto I3 = ck::Number<3>{};
//__host__ __device__ constexpr void
//operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
//{
//const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
//const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
//auto r_v_t = ck::vector_type<ck::half_t, 4>{};
//r_v_t.AsType<ck::half_t>()(I0) =
//scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
//r_v_t.AsType<ck::half_t>()(I1) =
//scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
//r_v_t.AsType<ck::half_t>()(I2) =
//scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
//r_v_t.AsType<ck::half_t>()(I3) =
//scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
//a = r_v_t.AsType<ck::half4_t>()[I0];
//}
//__host__ __device__ constexpr void
//operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
//{
//a = scale * (a0 + a1);
//}
//static constexpr ck::index_t vec_len = 4;
//float scale = 1.0;
//};
struct
AlphaBetaAdd
struct
AlphaBetaAdd
{
{
AlphaBetaAdd
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
AlphaBetaAdd
(
float
alpha
,
float
beta
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
...
@@ -96,17 +60,26 @@ struct AlphaBetaAdd
...
@@ -96,17 +60,26 @@ struct AlphaBetaAdd
float
beta_
;
float
beta_
;
};
};
using
AElementOp
=
PassThrough
;
struct
Multiply
{
__host__
__device__
constexpr
void
operator
()(
ck
::
half_t
&
a
,
const
ck
::
half_t
&
a0
,
const
float
&
a1
)
const
{
a
=
a0
*
a1
;
}
};
using
AElementOp
=
Multiply
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
AlphaBetaAdd
;
using
CDEElementOp
=
AlphaBetaAdd
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleABD_Xdl_CShuffle
<
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleABD_Xdl_CShuffle
<
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
ck
::
Tuple
<
ADataType
>
,
ck
::
Tuple
<
A
0DataType
,
A1
DataType
>
,
ck
::
Tuple
<
BDataType
>
,
ck
::
Tuple
<
BDataType
>
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -152,22 +125,15 @@ int main(int argc, char* argv[])
...
@@ -152,22 +125,15 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
//// GEMM shape
//ck::index_t M = 3840;
//ck::index_t N = 4096;
//ck::index_t K = 4096;
//ck::index_t StrideA = 4096;
//ck::index_t StrideB = 4096;
//ck::index_t StrideD = 4096;
//ck::index_t StrideE = 4096;
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
float
beta
=
1.0
f
;
// A[M0, M1, K0, K1]
// A0[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a0_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
std
::
vector
<
ck
::
index_t
>
a0_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// A1[M1, K1] -> A1[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a1_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a1_ms_ks_strides
{
0
,
64
,
0
,
1
};
// B[N0, N1, K0, K1]
// B[N0, N1, K0, K1]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
...
@@ -188,64 +154,23 @@ int main(int argc, char* argv[])
...
@@ -188,64 +154,23 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
//else if(argc == 6)
//{
//do_verification = std::stoi(argv[1]);
//init_method = std::stoi(argv[2]);
//time_kernel = std::stoi(argv[3]);
//alpha = std::stof(argv[4]);
//beta = std::stof(argv[5]);
//}
//else if(argc == 13)
//{
//do_verification = std::stoi(argv[1]);
//init_method = std::stoi(argv[2]);
//time_kernel = std::stoi(argv[3]);
//M = std::stoi(argv[4]);
//N = std::stoi(argv[5]);
//K = std::stoi(argv[6]);
//StrideA = std::stoi(argv[7]);
//StrideB = std::stoi(argv[8]);
//StrideD = std::stoi(argv[9]);
//StrideE = std::stoi(argv[10]);
//alpha = std::stof(argv[11]);
//beta = std::stof(argv[12]);
//}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
//auto f_host_tensor_descriptor =
Tensor
<
A0DataType
>
a0_ms_ks
(
a0_ms_ks_lengths
,
a0_ms_ks_strides
);
//[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
Tensor
<
A1DataType
>
a1_ms_ks
(
a1_ms_ks_lengths
,
a1_ms_ks_strides
);
//using namespace ck::literals;
//if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
//{
//return HostTensorDescriptor({row, col}, {stride, 1_uz});
//}
//else
//{
//return HostTensorDescriptor({row, col}, {1_uz, stride});
//}
//};
Tensor
<
ADataType
>
a_ms_ks
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
EDataType
>
d_ms_ns
(
d_ms_ns_lengths
,
d_ms_ns_strides
);
Tensor
<
EDataType
>
d_ms_ns
(
d_ms_ns_lengths
,
d_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
std
::
cout
<<
"a_ms_ks: "
<<
a_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a0_ms_ks: "
<<
a0_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a1_ms_ks: "
<<
a1_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks: "
<<
b_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks: "
<<
b_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns: "
<<
d_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns: "
<<
d_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns: "
<<
e_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns: "
<<
e_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
...
@@ -254,35 +179,27 @@ int main(int argc, char* argv[])
...
@@ -254,35 +179,27 @@ int main(int argc, char* argv[])
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a0_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
5
,
5
});
a1_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
A1DataType
>
{
-
5
,
5
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a0_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
a1_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
A1DataType
>
{
0.0
,
1.0
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
break
;
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a1_device_buf
(
sizeof
(
A1DataType
)
*
a1_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result
.
mDesc
.
GetElementSpaceSize
());
//Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
//Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
a0_device_buf
.
ToDevice
(
a0_ms_ks
.
mData
.
data
());
//Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
a1_device_buf
.
ToDevice
(
a1_ms_ks
.
mData
.
data
());
//Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
//Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
//Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
//std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
//std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
//std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
//std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
//std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
a_device_buf
.
ToDevice
(
a_ms_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_ns_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_ns_ks
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_ms_ns
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_ms_ns
.
mData
.
data
());
...
@@ -296,13 +213,14 @@ int main(int argc, char* argv[])
...
@@ -296,13 +213,14 @@ int main(int argc, char* argv[])
// do GEMM
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
auto
argument
=
device_op
.
MakeArgument
(
device_op
.
MakeArgument
(
std
::
array
<
const
void
*
,
1
>
{
a_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
2
>
{
a0_device_buf
.
GetDeviceBuffer
(),
a1_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
b_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
b_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
e_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
a_ms_ks_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
a
0_ms_ks_lengths
,
a1
_ms_ks_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
a_ms_ks_strides
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
a
0_ms_ks_strides
,
a1
_ms_ks_strides
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
b_ns_ks_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
b_ns_ks_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
b_ns_ks_strides
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
b_ns_ks_strides
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
...
@@ -320,64 +238,93 @@ int main(int argc, char* argv[])
...
@@ -320,64 +238,93 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
//float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
//std::size_t flop = std::size_t(2) * M * N * K;
if
(
time_kernel
)
//std::size_t num_btype =
{
//sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
ck
::
index_t
M
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
(),
NumDimM
,
1
,
std
::
multiplies
<>
{});
//float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
ck
::
index_t
N
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
NumDimN
,
1
,
std
::
multiplies
<>
{});
//float gb_per_sec = num_btype / 1.E6 / ave_time;
ck
::
index_t
K
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
a0_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
//std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
//<< std::endl;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
+
sizeof
(
EDataType
)
*
M
*
N
;
//e_device_buf.FromDevice(e_m_n_device_result.mData.data());
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
}
if
(
do_verification
)
if
(
do_verification
)
{
{
#if 0
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<ADataType> a_m_k({M, K});
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
A0DataType
>
a_ms_ks
(
a0_ms_ks_lengths
,
a0_ms_ks_strides
);
for(
in
t m = 0; m <
M
; ++m)
for
(
size_
t
m
0
=
0
;
m
0
<
a_ms_ks
.
mDesc
.
GetLengths
()[
0
]
;
++
m
0
)
{
{
for(
int k
= 0;
k
<
K
; ++
k
)
for
(
size_t
m1
=
0
;
m1
<
a_ms_ks
.
mDesc
.
GetLengths
()[
1
]
;
++
m1
)
{
{
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
for
(
size_t
k0
=
0
;
k0
<
a_ms_ks
.
mDesc
.
GetLengths
()[
2
];
++
k0
)
{
for
(
size_t
k1
=
0
;
k1
<
a_ms_ks
.
mDesc
.
GetLengths
()[
3
];
++
k1
)
{
a_element_op
(
a_ms_ks
(
m0
,
m1
,
k0
,
k1
),
a0_ms_ks
(
m0
,
m1
,
k0
,
k1
),
a1_ms_ks
(
m0
,
m1
,
k0
,
k1
));
}
}
}
}
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
NumDimK
,
A0DataType
,
BDataType
,
BDataType
,
CShuffleDataType
,
CShuffleDataType
,
AccDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
BElementOp
,
BElementOp
>
;
PassThrough>;
auto ref_
gemm
= Reference
Gemm
Instance{};
auto
ref_
op
=
Reference
Op
Instance
{};
auto ref_invoker
= ref_
gemm
.MakeInvoker();
auto
ref_invoker
=
ref_
op
.
MakeInvoker
();
Tensor
<
float
>
empty_tensor
(
std
::
vector
<
ck
::
index_t
>
{},
std
::
vector
<
ck
::
index_t
>
{});
auto
ref_argument
=
auto
ref_argument
=
ref_
gemm
.MakeArgument(a_m_k, b_
k_n
, c_m_n, PassThrough{}, b_element_op
, PassThrough{}
);
ref_
op
.
MakeArgument
(
a_m
s
_k
s
,
b_
ns_ks
,
c_m
s
_n
s_host_result
,
PassThrough
{},
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
for(
in
t m = 0; m <
M
; ++m)
for
(
size_
t
m
0
=
0
;
m
0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
]
;
++
m
0
)
{
{
for(
int n
= 0;
n
<
N
; ++
n
)
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
]
;
++
m1
)
{
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
d_ms_ns
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_device_buf
.
FromDevice
(
e_m
s
_n
s
_device_result
.
mData
.
data
());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
#endif
}
}
return
0
;
return
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
f60ad8b9
...
@@ -501,7 +501,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -501,7 +501,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
a_kz_stride_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
];
a_kz_stride_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
];
}
}
for
(
index_t
i
=
0
;
i
<
Num
A
Tensor
;
++
i
)
for
(
index_t
i
=
0
;
i
<
Num
B
Tensor
;
++
i
)
{
{
b_nz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
];
b_nz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
];
b_kz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
];
b_kz_stride_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
];
...
@@ -697,8 +697,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
...
@@ -697,8 +697,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
});
});
// check vector load of Ds
// check vector load of Ds
// only support RowMajor for now
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
if
(
!
(
arg
.
ds_nz_stride_
[
i
]
==
1
&&
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
i
].
GetLength
(
I3
)
%
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
i
].
GetLength
(
I3
)
%
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment