Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_onnxruntime
Commits
35e49f2d
You need to sign in or sign up before continuing.
Unverified
Commit
35e49f2d
authored
Aug 12, 2022
by
zjing14
Committed by
GitHub
Aug 12, 2022
Browse files
add g; fixed strides (#355)
parent
de60d290
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
279 additions
and
250 deletions
+279
-250
example/25_gemm_bias_e_permute/CMakeLists.txt
example/25_gemm_bias_e_permute/CMakeLists.txt
+2
-2
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
+131
-111
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
+141
-128
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...ce/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+5
-9
No files found.
example/25_gemm_bias_e_permute/CMakeLists.txt
View file @
35e49f2d
add_example_executable
(
example_gemm_bias_e_permute_m3n2_xdl_fp16 gemm_bias_e_permute_m3n2_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_m2n3_xdl_fp16 gemm_bias_e_permute_m2n3_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16 gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16 gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16.cpp
)
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m2n3_xdl_fp16.cpp
→
example/25_gemm_bias_e_permute/gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16.cpp
View file @
35e49f2d
...
...
@@ -16,6 +16,8 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -33,7 +35,7 @@ using DDataType = F16;
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
using
EDataType
=
F16
;
static
constexpr
ck
::
index_t
NumDimG
=
0
;
static
constexpr
ck
::
index_t
NumDimG
=
1
;
static
constexpr
ck
::
index_t
NumDimM
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
3
;
static
constexpr
ck
::
index_t
NumDimK
=
1
;
...
...
@@ -69,30 +71,31 @@ template <ck::index_t NumDimM,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ck
::
enable_if_t
<
NumDimM
==
2
&&
NumDimN
==
3
&&
NumDimK
==
1
,
bool
>
=
false
>
struct
ReferenceContraction_M2_N3_K1
:
public
ck
::
tensor_operation
::
device
::
BaseOperator
ck
::
enable_if_t
<
NumDimG
==
1
&&
NumDimM
==
2
&&
NumDimN
==
3
&&
NumDimK
==
1
,
bool
>
=
false
>
struct
ReferenceContraction_G1_M2_N3_K1
:
public
ck
::
tensor_operation
::
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
Tensor
<
EDataType
>&
e_ms_ns
,
Argument
(
const
Tensor
<
ADataType
>&
a_
gs_
ms_ks
,
const
Tensor
<
BDataType
>&
b_
gs_
ns_ks
,
Tensor
<
EDataType
>&
e_
gs_
ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
a_ms_ks_
{
a_ms_ks
},
b_ns_ks_
{
b_ns_ks
},
e_ms_ns_
{
e_ms_ns
},
:
a_
gs_
ms_ks_
{
a_
gs_
ms_ks
},
b_
gs_
ns_ks_
{
b_
gs_
ns_ks
},
e_
gs_
ms_ns_
{
e_
gs_
ms_ns
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_ms_ks_
;
const
Tensor
<
BDataType
>&
b_ns_ks_
;
Tensor
<
EDataType
>&
e_ms_ns_
;
const
Tensor
<
ADataType
>&
a_
gs_
ms_ks_
;
const
Tensor
<
BDataType
>&
b_
gs_
ns_ks_
;
Tensor
<
EDataType
>&
e_
gs_
ms_ns_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
...
...
@@ -102,12 +105,12 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
// Invoker
struct
Invoker
:
public
ck
::
tensor_operation
::
device
::
BaseInvoker
{
using
Argument
=
ReferenceContraction_M2_N3_K1
::
Argument
;
using
Argument
=
ReferenceContraction_
G1_
M2_N3_K1
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
auto
f_ms_ns
=
[
&
](
auto
m0
,
auto
m1
,
auto
n0
,
auto
n1
,
auto
n2
)
{
const
int
K0
=
arg
.
a_ms_ks_
.
mDesc
.
GetLengths
()[
2
];
auto
f_
gs_
ms_ns
=
[
&
](
auto
g0
,
auto
m0
,
auto
m1
,
auto
n0
,
auto
n1
,
auto
n2
)
{
const
int
K0
=
arg
.
a_
gs_
ms_ks_
.
mDesc
.
GetLengths
()[
3
];
AccDataType
v_acc
=
0
;
...
...
@@ -117,9 +120,10 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
)));
v_a
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
a_
gs_
ms_ks_
(
g0
,
m0
,
m1
,
k0
)));
arg
.
b_element_op_
(
v_b
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
n2
,
k0
)));
v_b
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
b_gs_ns_ks_
(
g0
,
n0
,
n1
,
n2
,
k0
)));
v_acc
+=
v_a
*
v_b
;
}
...
...
@@ -128,15 +132,16 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
arg
.
cde_element_op_
(
v_c
,
v_acc
);
arg
.
e_ms_ns_
(
m0
,
m1
,
n0
,
n1
,
n2
)
=
v_c
;
arg
.
e_
gs_
ms_ns_
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
)
=
v_c
;
};
make_ParallelTensorFunctor
(
f_ms_ns
,
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
3
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
4
])(
make_ParallelTensorFunctor
(
f_gs_ms_ns
,
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
3
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
4
],
arg
.
e_gs_ms_ns_
.
mDesc
.
GetLengths
()[
5
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
...
...
@@ -160,14 +165,15 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
Tensor
<
EDataType
>&
e_ms_ns
,
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_
gs_
ms_ks
,
const
Tensor
<
BDataType
>&
b_
gs_
ns_ks
,
Tensor
<
EDataType
>&
e_
gs_
ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_ms_ks
,
b_ns_ks
,
e_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
return
Argument
{
a_gs_ms_ks
,
b_gs_ns_ks
,
e_gs_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -196,28 +202,31 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
bool
time_kernel
=
false
;
ck
::
index_t
G0
=
1
;
ck
::
index_t
M0
=
4
;
ck
::
index_t
M1
=
256
;
ck
::
index_t
N0
=
4
;
ck
::
index_t
N1
=
8
;
ck
::
index_t
N2
=
128
;
ck
::
index_t
N1
=
16
;
ck
::
index_t
N2
=
32
;
ck
::
index_t
K0
=
256
;
// A[G0, M0, M1, K0]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
M0
,
M1
,
K0
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
M1
*
K0
,
K0
,
1
};
std
::
vector
<
ck
::
index_t
>
a_
gs_
ms_ks_lengths
{
G0
,
M0
,
M1
,
K0
};
std
::
vector
<
ck
::
index_t
>
a_
gs_
ms_ks_strides
{
M0
*
M1
*
K0
,
M1
*
K0
,
K0
,
1
};
// B[G0, N0, N1, N2, K0]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
N0
,
N1
,
N2
,
K0
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
N1
*
N2
*
K0
,
N2
*
K0
,
K0
,
1
};
std
::
vector
<
ck
::
index_t
>
b_
gs_
ns_ks_lengths
{
G0
,
N0
,
N1
,
N2
,
K0
};
std
::
vector
<
ck
::
index_t
>
b_
gs_
ns_ks_strides
{
N0
*
N1
*
N2
*
K0
,
N1
*
N2
*
K0
,
N2
*
K0
,
K0
,
1
};
// D[G0, M0, M1, N0, N1, N2] (zero strides on M0 and M1: D is broadcast along the M dimensions)
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
M0
,
M1
,
N0
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
0
,
0
,
N1
*
N2
,
N
1
,
1
};
std
::
vector
<
ck
::
index_t
>
d_
gs_
ms_ns_lengths
{
G0
,
M0
,
M1
,
N0
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
d_
gs_
ms_ns_strides
{
N0
*
N1
*
N2
,
0
,
0
,
N1
*
N2
,
N
2
,
1
};
// E[G0, M0, M1, N0, N1, N2] (strides place it in memory as [G0, N0, M0, N1, M1, N2])
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
M0
,
M1
,
N0
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
N1
*
M1
*
N2
,
N2
,
M0
*
N1
*
M1
*
N2
,
M1
*
N2
,
1
};
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_lengths
{
G0
,
M0
,
M1
,
N0
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_strides
{
M0
*
M1
*
N0
*
N1
*
N2
,
N1
*
M1
*
N2
,
N2
,
M0
*
N1
*
M1
*
N2
,
M1
*
N2
,
1
};
if
(
argc
==
1
)
{
...
...
@@ -237,50 +246,51 @@ int main(int argc, char* argv[])
exit
(
0
);
}
Tensor
<
ADataType
>
a_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_strides
.
begin
(),
a_ms_ks_strides
.
end
()));
Tensor
<
BDataType
>
b_ns_ks
(
std
::
vector
<
std
::
size_t
>
(
b_ns_ks_lengths
.
begin
(),
b_ns_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
b_ns_ks_strides
.
begin
(),
b_ns_ks_strides
.
end
()));
Tensor
<
DDataType
>
d_ms_ns
(
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_lengths
.
begin
(),
d_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_strides
.
begin
(),
d_ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
std
::
cout
<<
"a_ms_ks: "
<<
a_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks: "
<<
b_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns: "
<<
d_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns: "
<<
e_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
Tensor
<
ADataType
>
a_
gs_
ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_
gs_
ms_ks_lengths
.
begin
(),
a_
gs_
ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_
gs_
ms_ks_strides
.
begin
(),
a_
gs_
ms_ks_strides
.
end
()));
Tensor
<
BDataType
>
b_
gs_
ns_ks
(
std
::
vector
<
std
::
size_t
>
(
b_
gs_
ns_ks_lengths
.
begin
(),
b_
gs_
ns_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
b_
gs_
ns_ks_strides
.
begin
(),
b_
gs_
ns_ks_strides
.
end
()));
Tensor
<
DDataType
>
d_
gs_
ms_ns
(
std
::
vector
<
std
::
size_t
>
(
d_
gs_
ms_ns_lengths
.
begin
(),
d_
gs_
ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
d_
gs_
ms_ns_strides
.
begin
(),
d_
gs_
ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_
gs_
ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_lengths
.
begin
(),
e_
gs_
ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_strides
.
begin
(),
e_
gs_
ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_
gs_
ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_lengths
.
begin
(),
e_
gs_
ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_strides
.
begin
(),
e_
gs_
ms_ns_strides
.
end
()));
std
::
cout
<<
"a_
gs_
ms_ks: "
<<
a_
gs_
ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_
gs_
ns_ks: "
<<
b_
gs_
ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_
gs_
ms_ns: "
<<
d_
gs_
ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_
gs_
ms_ns: "
<<
e_
gs_
ms_ns_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
a_
gs_
ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_
gs_
ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_
gs_
ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
default:
a_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
a_
gs_
ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_
gs_
ns_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_
gs_
ms_ns
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_gs_ms_ns_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_ms_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_ns_ks
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_ms_ns
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_
gs_
ms_ks
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_
gs_
ns_ks
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_
gs_
ms_ns
.
mData
.
data
());
// set zero
e_device_buf
.
SetZero
();
...
...
@@ -296,14 +306,14 @@ int main(int argc, char* argv[])
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_
gs_
ms_ks_lengths
,
a_
gs_
ms_ks_strides
,
b_
gs_
ns_ks_lengths
,
b_
gs_
ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_
gs_
ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_
gs_
ms_ns_strides
},
e_
gs_
ms_ns_lengths
,
e_
gs_
ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
...
...
@@ -317,18 +327,18 @@ int main(int argc, char* argv[])
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
ck
::
index
_t
M
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
std
::
size
_t
M
=
std
::
accumulate
(
e_
gs_
ms_ns_lengths
.
begin
()
+
NumDimG
,
e_
gs_
ms_ns_lengths
.
begin
()
+
NumDimG
+
NumDimM
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index
_t
N
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
e_ms_ns_lengths
.
begin
()
+
NumDimM
+
NumDimN
,
std
::
size
_t
N
=
std
::
accumulate
(
e_
gs_
ms_ns_lengths
.
begin
()
+
NumDimG
+
NumDimM
,
e_
gs_
ms_ns_lengths
.
begin
()
+
NumDimG
+
NumDimM
+
NumDimN
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index
_t
K
=
std
::
accumulate
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
a_ms_ks_lengths
.
begin
()
+
NumDimM
+
NumDimK
,
std
::
size
_t
K
=
std
::
accumulate
(
a_
gs_
ms_ks_lengths
.
begin
()
+
NumDimG
+
NumDimM
,
a_
gs_
ms_ks_lengths
.
begin
()
+
NumDimG
+
NumDimM
+
NumDimK
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
...
...
@@ -343,15 +353,15 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_ms_ns_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_
gs_
ms_ns_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
Tensor
<
CShuffleDataType
>
c_
gs_
ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_lengths
.
begin
(),
e_
gs_
ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_strides
.
begin
(),
e_
gs_
ms_ns_strides
.
end
()));
using
ReferenceOpInstance
=
ReferenceContraction_M2_N3_K1
<
NumDimM
,
using
ReferenceOpInstance
=
ReferenceContraction_
G1_
M2_N3_K1
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
...
...
@@ -365,31 +375,41 @@ int main(int argc, char* argv[])
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
,
PassThrough
{});
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_gs_ms_ks
,
b_gs_ns_ks
,
c_gs_ms_ns_host_result
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
for
(
size_t
g0
=
0
;
g0
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
g0
)
{
for
(
size_t
m0
=
0
;
m0
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
for
(
size_t
m1
=
0
;
m1
<
e_
gs_
ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
for
(
size_t
n0
=
0
;
n0
<
e_
gs_
ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
for
(
size_t
n1
=
0
;
n1
<
e_
gs_
ms_ns_host_result
.
mDesc
.
GetLengths
()[
4
];
++
n1
)
{
for
(
size_t
n2
=
0
;
n2
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
4
];
++
n2
)
for
(
size_t
n2
=
0
;
n2
<
e_gs_ms_ns_host_result
.
mDesc
.
GetLengths
()[
5
];
++
n2
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
,
n2
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
,
n2
),
d_ms_ns
(
m0
,
m1
,
n0
,
n1
,
n2
));
cde_element_op
(
e_gs_ms_ns_host_result
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
),
c_gs_ms_ns_host_result
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
),
d_gs_ms_ns
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
));
}
}
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
.
mData
,
e_ms_ns_host_result
.
mData
)
?
0
:
1
;
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
.
mData
,
e_gs_ms_ns_host_result
.
mData
)
?
0
:
1
;
}
return
0
;
...
...
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m3n2_xdl_fp16.cpp
→
example/25_gemm_bias_e_permute/gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16.cpp
View file @
35e49f2d
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
35e49f2d
...
...
@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
NumDimG
>
0
)
ds_offset
[
i
]
=
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
else
ds_offset
[
i
]
=
0
;
});
return
ds_offset
;
...
...
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
if
constexpr
(
NumDimG
>
0
)
return
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
else
return
0
;
}
private:
...
...
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
compute_ptr_offset_of_batch_
{
a_batch_stride_
,
b_batch_stride_
,
ds_grid_desc_g_m_n_
,
e_grid_desc_g_m_n_
}
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
,
""
);
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment