Repository: yangql/composable_kernel-1

Commit 35e49f2d (unverified), authored Aug 12, 2022 by zjing14; committed by GitHub on Aug 12, 2022.

    add g; fixed strides (#355)

Parent: de60d290
Showing 4 changed files, with 279 additions and 250 deletions.
example/25_gemm_bias_e_permute/CMakeLists.txt (+2, -2)
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp (+131, -111)
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp (+141, -128)
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+5, -9)
example/25_gemm_bias_e_permute/CMakeLists.txt

-add_example_executable(example_gemm_bias_e_permute_m3n2_xdl_fp16 gemm_bias_e_permute_m3n2_xdl_fp16.cpp)
+add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
-add_example_executable(example_gemm_bias_e_permute_m2n3_xdl_fp16 gemm_bias_e_permute_m2n3_xdl_fp16.cpp)
+add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m2n3_xdl_fp16.cpp → example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
@@ -16,6 +16,8 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -33,7 +35,7 @@ using DDataType = F16;
 using DsDataType = ck::Tuple<DDataType>;
 using EDataType  = F16;

-static constexpr ck::index_t NumDimG = 0;
+static constexpr ck::index_t NumDimG = 1;
 static constexpr ck::index_t NumDimM = 2;
 static constexpr ck::index_t NumDimN = 3;
 static constexpr ck::index_t NumDimK = 1;
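With the dimension counts above (NumDimG = 1, NumDimM = 2, NumDimN = 3, NumDimK = 1), each tensor's rank in the renamed g1m2n3k1 example is the sum of its dimension counts: A carries the G, M and K dimensions, B the G, N and K dimensions, and E the G, M and N dimensions. A minimal sketch of that bookkeeping, where the constant names mirror the diff and everything else is illustrative:

// Hedged sketch: tensor ranks implied by NumDimG/M/N/K in this example (illustrative only).
#include <cstddef>

constexpr std::size_t NumDimG = 1, NumDimM = 2, NumDimN = 3, NumDimK = 1;

constexpr std::size_t RankA = NumDimG + NumDimM + NumDimK; // A[G0, M0, M1, K0]         -> 4
constexpr std::size_t RankB = NumDimG + NumDimN + NumDimK; // B[G0, N0, N1, N2, K0]     -> 5
constexpr std::size_t RankE = NumDimG + NumDimM + NumDimN; // E[G0, M0, M1, N0, N1, N2] -> 6

static_assert(RankA == 4 && RankB == 5 && RankE == 6, "ranks used throughout the example");

int main() { return 0; }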
@@ -69,30 +71,31 @@ template <ck::index_t NumDimM,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> = false>
-struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
+          ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> = false>
+struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
 {
     // Argument
     struct Argument : public ck::tensor_operation::device::BaseArgument
     {
-        Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ns_ks,
-                 Tensor<EDataType>& e_ms_ns,
+        Argument(const Tensor<ADataType>& a_gs_ms_ks,
+                 const Tensor<BDataType>& b_gs_ns_ks,
+                 Tensor<EDataType>& e_gs_ms_ns,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CDEElementwiseOperation cde_element_op)
-            : a_ms_ks_{a_ms_ks},
-              b_ns_ks_{b_ns_ks},
-              e_ms_ns_{e_ms_ns},
+            : a_gs_ms_ks_{a_gs_ms_ks},
+              b_gs_ns_ks_{b_gs_ns_ks},
+              e_gs_ms_ns_{e_gs_ms_ns},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op}
         {
         }

-        const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ns_ks_;
-        Tensor<EDataType>& e_ms_ns_;
+        const Tensor<ADataType>& a_gs_ms_ks_;
+        const Tensor<BDataType>& b_gs_ns_ks_;
+        Tensor<EDataType>& e_gs_ms_ns_;

         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -102,12 +105,12 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
     // Invoker
     struct Invoker : public ck::tensor_operation::device::BaseInvoker
     {
-        using Argument = ReferenceContraction_M2_N3_K1::Argument;
+        using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;

         float Run(const Argument& arg)
         {
-            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1, auto n2) {
-                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
+            auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
+                const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];

                 AccDataType v_acc = 0;
@@ -117,9 +120,10 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
                     AccDataType v_b;

-                    arg.a_element_op_(v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0)));
-                    arg.b_element_op_(v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, n2, k0)));
+                    arg.a_element_op_(v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
+                    arg.b_element_op_(v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));

                     v_acc += v_a * v_b;
                 }
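Taken together, the Argument and lambda changes above make the host reference reduce over K at every (g0, m0, m1, n0, n1, n2) point instead of (m0, m1, n0, n1, n2). A minimal standalone sketch of that computation with plain packed row-major buffers and small illustrative sizes; the loop nest mirrors the diff, but the buffer layout, sizes, and names are assumptions, not the repo's Tensor/HostTensorDescriptor types:

// Hedged sketch of the batched reference contraction:
//   e[g0, m0, m1, n0, n1, n2] = sum_k a[g0, m0, m1, k] * b[g0, n0, n1, n2, k]
// Plain std::vector buffers with packed row-major layouts; illustrative sizes only.
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t G0 = 1, M0 = 2, M1 = 3, N0 = 2, N1 = 2, N2 = 2, K0 = 4;

    std::vector<float> a(G0 * M0 * M1 * K0, 1.0f);           // a[g0][m0][m1][k]
    std::vector<float> b(G0 * N0 * N1 * N2 * K0, 2.0f);      // b[g0][n0][n1][n2][k]
    std::vector<float> e(G0 * M0 * M1 * N0 * N1 * N2, 0.0f); // e[g0][m0][m1][n0][n1][n2]

    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t m0 = 0; m0 < M0; ++m0)
            for(std::size_t m1 = 0; m1 < M1; ++m1)
                for(std::size_t n0 = 0; n0 < N0; ++n0)
                    for(std::size_t n1 = 0; n1 < N1; ++n1)
                        for(std::size_t n2 = 0; n2 < N2; ++n2)
                        {
                            float acc = 0.0f;
                            for(std::size_t k = 0; k < K0; ++k)
                            {
                                const float va = a[((g0 * M0 + m0) * M1 + m1) * K0 + k];
                                const float vb = b[(((g0 * N0 + n0) * N1 + n1) * N2 + n2) * K0 + k];
                                acc += va * vb;
                            }
                            e[((((g0 * M0 + m0) * M1 + m1) * N0 + n0) * N1 + n1) * N2 + n2] = acc;
                        }

    return e.front() == 8.0f ? 0 : 1; // each output element is K0 * 1 * 2 = 8
}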
@@ -128,15 +132,16 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
                 arg.cde_element_op_(v_c, v_acc);

-                arg.e_ms_ns_(m0, m1, n0, n1, n2) = v_c;
+                arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
             };

-            make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[4])(
+            make_ParallelTensorFunctor(f_gs_ms_ns,
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
+                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                 std::thread::hardware_concurrency());

             return 0;
@@ -160,14 +165,15 @@ struct ReferenceContraction_M2_N3_K1 : public ck::tensor_operation::device::Base
         return true;
     }

-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
+    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
+                             const Tensor<BDataType>& b_gs_ns_ks,
+                             Tensor<EDataType>& e_gs_ms_ns,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation cde_element_op)
    {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
+        return Argument{a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -196,28 +202,31 @@ int main(int argc, char* argv[])
     int init_method  = 1;
     bool time_kernel = false;

+    ck::index_t G0 = 1;
+
     ck::index_t M0 = 4;
     ck::index_t M1 = 256;

     ck::index_t N0 = 4;
-    ck::index_t N1 = 8;
-    ck::index_t N2 = 128;
+    ck::index_t N1 = 16;
+    ck::index_t N2 = 32;

     ck::index_t K0 = 256;

     // A[M0, M1, M2, K0]
-    std::vector<ck::index_t> a_ms_ks_lengths{M0, M1, K0};
-    std::vector<ck::index_t> a_ms_ks_strides{M1 * K0, K0, 1};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, K0};
+    std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1};
     // B[N0, N1, K0]
-    std::vector<ck::index_t> b_ns_ks_lengths{N0, N1, N2, K0};
-    std::vector<ck::index_t> b_ns_ks_strides{N1 * N2 * K0, N2 * K0, K0, 1};
+    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, N2, K0};
+    std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1};
     // D[N0, M0, N1, M1, N2]
-    std::vector<ck::index_t> d_ms_ns_lengths{M0, M1, N0, N1, N2};
-    std::vector<ck::index_t> d_ms_ns_strides{0, 0, N1 * N2, N1, 1};
+    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
+    std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1};
     // E[N0, M0, N1, M1, N2]
-    std::vector<ck::index_t> e_ms_ns_lengths{M0, M1, N0, N1, N2};
-    std::vector<ck::index_t> e_ms_ns_strides{N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};
+    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
+    std::vector<ck::index_t> e_gs_ms_ns_strides{
+        M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};

     if(argc == 1)
     {
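The "fixed strides" half of the commit is visible in the hunk above: every tensor gains a leading batch stride equal to its per-batch element count, the bias D keeps stride 0 on the M dimensions so it broadcasts over M, and its N1 stride changes from N1 to N2, which is what a packed [N0, N1, N2] ordering requires. A minimal sketch of that stride rule, assuming packed rightmost-fastest ordering for the non-broadcast dimensions; the helper below is ours, not a library function:

// Hedged sketch: building packed, rightmost-fastest strides while zeroing broadcast dims.
// With lengths {G0, M0, M1, N0, N1, N2} and M0/M1 broadcast, this reproduces the
// d_gs_ms_ns_strides used above: {N0*N1*N2, 0, 0, N1*N2, N2, 1}.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::int64_t> make_strides(const std::vector<std::int64_t>& lengths,
                                       const std::vector<bool>& is_broadcast)
{
    std::vector<std::int64_t> strides(lengths.size(), 0);
    std::int64_t running = 1;
    for(std::size_t i = lengths.size(); i-- > 0;)
    {
        if(!is_broadcast[i])
        {
            strides[i] = running;   // rightmost non-broadcast dim gets stride 1
            running *= lengths[i];  // broadcast dims keep stride 0 and add no extent
        }
    }
    return strides;
}

int main()
{
    const std::int64_t G0 = 1, M0 = 4, M1 = 256, N0 = 4, N1 = 16, N2 = 32;

    // D[G0, M0, M1, N0, N1, N2], broadcast over M0 and M1:
    const auto d_strides = make_strides({G0, M0, M1, N0, N1, N2},
                                        {false, true, true, false, false, false});

    for(const auto s : d_strides)
        std::cout << s << ' '; // expected: 2048 0 0 512 32 1  (= N0*N1*N2, 0, 0, N1*N2, N2, 1)
    std::cout << '\n';
}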
@@ -237,50 +246,51 @@ int main(int argc, char* argv[])
         exit(0);
     }

-    Tensor<ADataType> a_ms_ks(
-        std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
-        std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
-    Tensor<BDataType> b_ns_ks(
-        std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
-        std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
-    Tensor<DDataType> d_ms_ns(
-        std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
-        std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
-    Tensor<EDataType> e_ms_ns_host_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
-    Tensor<EDataType> e_ms_ns_device_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
+    Tensor<ADataType> a_gs_ms_ks(
+        std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
+        std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
+    Tensor<BDataType> b_gs_ns_ks(
+        std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
+        std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
+    Tensor<DDataType> d_gs_ms_ns(
+        std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
+        std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
+    Tensor<EDataType> e_gs_ms_ns_host_result(
+        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
+        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
+    Tensor<EDataType> e_gs_ms_ns_device_result(
+        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
+        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

-    std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
-    std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
-    std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
-    std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
+    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
+    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
+    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
+    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;

     switch(init_method)
     {
     case 0: break;
     case 1:
-        a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-        b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-        d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-        b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-        d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     }

-    DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

-    a_device_buf.ToDevice(a_ms_ks.mData.data());
-    b_device_buf.ToDevice(b_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_ms_ns.mData.data());
+    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
+    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
+    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());

     // set zero
     e_device_buf.SetZero();
@@ -296,14 +306,14 @@ int main(int argc, char* argv[])
        b_device_buf.GetDeviceBuffer(),
        std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
        e_device_buf.GetDeviceBuffer(),
-       a_ms_ks_lengths,
-       a_ms_ks_strides,
-       b_ns_ks_lengths,
-       b_ns_ks_strides,
-       std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-       std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-       e_ms_ns_lengths,
-       e_ms_ns_strides,
+       a_gs_ms_ks_lengths,
+       a_gs_ms_ks_strides,
+       b_gs_ns_ks_lengths,
+       b_gs_ns_ks_strides,
+       std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
+       std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
+       e_gs_ms_ns_lengths,
+       e_gs_ms_ns_strides,
        a_element_op,
        b_element_op,
        cde_element_op);
@@ -317,18 +327,18 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

-    ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
-                                    e_ms_ns_lengths.begin() + NumDimM,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
+                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
+                                    ck::index_t{1},
+                                    std::multiplies<ck::index_t>{});

-    ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
-                                    e_ms_ns_lengths.begin() + NumDimM + NumDimN,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
+                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
+                                    ck::index_t{1},
+                                    std::multiplies<ck::index_t>{});

-    ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
-                                    a_ms_ks_lengths.begin() + NumDimM + NumDimK,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
+                                    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
+                                    ck::index_t{1},
+                                    std::multiplies<ck::index_t>{});
@@ -343,15 +353,15 @@ int main(int argc, char* argv[])
...
@@ -343,15 +353,15 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_ms_ns_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_
gs_
ms_ns_device_result
.
mData
.
data
());
if
(
do_verification
)
if
(
do_verification
)
{
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
Tensor
<
CShuffleDataType
>
c_
gs_
ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_lengths
.
begin
(),
e_
gs_
ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
std
::
vector
<
std
::
size_t
>
(
e_
gs_
ms_ns_strides
.
begin
(),
e_
gs_
ms_ns_strides
.
end
()));
using
ReferenceOpInstance
=
ReferenceContraction_M2_N3_K1
<
NumDimM
,
using
ReferenceOpInstance
=
ReferenceContraction_
G1_
M2_N3_K1
<
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
ADataType
,
ADataType
,
@@ -365,31 +375,41 @@ int main(int argc, char* argv[])
        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
-           a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+           a_gs_ms_ks, b_gs_ns_ks, c_gs_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

-       for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+       for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
+       {
+           for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
            {
-               for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+               for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
                {
-                   for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+                   for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
                    {
-                       for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+                       for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
                        {
-                           for(size_t n2 = 0; n2 < e_ms_ns_host_result.mDesc.GetLengths()[4]; ++n2)
+                           for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; ++n2)
                            {
-                               cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1, n2),
-                                              c_ms_ns_host_result(m0, m1, n0, n1, n2),
-                                              d_ms_ns(m0, m1, n0, n1, n2));
+                               cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
+                                              c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
+                                              d_gs_ms_ns(g0, m0, m1, n0, n1, n2));
                            }
                        }
                    }
                }
            }
+       }

-       return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
+       return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) ? 0 : 1;
    }

    return 0;
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m3n2_xdl_fp16.cpp → example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
(diff collapsed in the page view; contents not shown)
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
            std::array<long_index_t, NumDTensor> ds_offset;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                if constexpr(NumDimG > 0)
-                    ds_offset[i] =
-                        ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
-                else
-                    ds_offset[i] = 0;
+                ds_offset[i] =
+                    ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
            });

            return ds_offset;
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
        __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
        {
-            if constexpr(NumDimG > 0)
-                return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
-            else
-                return 0;
+            return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
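Because the constructor now asserts NumDimG > 0 (see the hunk below), the if constexpr(NumDimG > 0) fallbacks in GetDsPtrOffset and GetEPtrOffset become dead branches and are removed: the per-batch offset is always taken from the G-dimension descriptor. A minimal sketch of that per-batch pointer-offset pattern, assuming a plain batch stride in place of the library's grid descriptors; the types and names below are ours:

// Hedged sketch: per-batch pointer offsetting, the pattern the simplified
// GetDsPtrOffset/GetEPtrOffset serve. A plain stride stands in for the grid descriptors.
#include <cstdint>
#include <vector>

struct BatchedView
{
    const float* base;
    std::int64_t batch_stride; // elements between consecutive g indices

    // Analogous to CalculateOffset(make_multi_index(g_idx, 0, 0)) on a G-M-N descriptor.
    const float* batch_ptr(std::int64_t g_idx) const { return base + g_idx * batch_stride; }
};

int main()
{
    const std::int64_t G0 = 2, M = 4, N = 8;
    std::vector<float> e(G0 * M * N, 0.0f);

    BatchedView view{e.data(), M * N};
    // A workgroup handling batch g would start reading/writing at view.batch_ptr(g).
    return (view.batch_ptr(1) - view.batch_ptr(0)) == M * N ? 0 : 1;
}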
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
              compute_ptr_offset_of_batch_{
                  a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
        {
+            static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
+
            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;