Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7a3b49e5
Commit
7a3b49e5
authored
Jun 25, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into contraction
parents
e07b3d8e
d3051d75
Changes
592
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
213 additions
and
604 deletions
+213
-604
library/src/obselete_driver_offline/gemm_driver_offline.cpp
library/src/obselete_driver_offline/gemm_driver_offline.cpp
+0
-456
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+19
-38
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
..._batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
...ice_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
+10
-5
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
+17
-15
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
+17
-15
No files found.
library/src/obselete_driver_offline/gemm_driver_offline.cpp
deleted
100644 → 0
View file @
e07b3d8e
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdlops_mk_kn_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
MK_KN_NM
,
// 4
MK_NK_NM
,
// 5
KM_KN_NM
,
// 6
KM_NK_NM
// 7
};
enum
struct
GemmAlgo
{
Xdl_MK_KN_MN
,
// 0
Xdl_MK_NK_MN
,
// 1
Xdl_KM_KN_MN
,
// 2
Xdl_KM_NK_MN
,
// 3
Xdl_MK_KN_NM
,
// 4
Xdl_MK_NK_NM
,
// 5
Xdl_KM_KN_NM
,
// 6
Xdl_KM_NK_NM
,
// 7
};
template
<
typename
AType
,
typename
BType
,
typename
CType
>
void
host_gemm
(
const
Tensor
<
AType
>&
a
,
const
Tensor
<
BType
>&
b
,
Tensor
<
CType
>&
c
,
const
GemmMatrixLayout
layout
)
{
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
1
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
m
,
k
))
*
static_cast
<
const
double
>
(
b
(
k
,
n
));
}
c
(
m
,
n
)
=
v
;
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
auto
f_mk_nk_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
1
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
m
,
k
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
m
,
n
)
=
v
;
};
make_ParallelTensorFunctor
(
f_mk_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
auto
f_km_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
0
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
k
,
m
))
*
static_cast
<
const
double
>
(
b
(
k
,
n
));
}
c
(
m
,
n
)
=
v
;
};
make_ParallelTensorFunctor
(
f_km_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
auto
f_km_nk_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
0
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
k
,
m
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
m
,
n
)
=
v
;
};
make_ParallelTensorFunctor
(
f_km_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
MK_KN_NM
)
{
auto
f_mk_kn_nm
=
[
&
](
auto
n
,
auto
m
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
1
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
m
,
k
))
*
static_cast
<
const
double
>
(
b
(
k
,
n
));
}
c
(
n
,
m
)
=
v
;
};
make_ParallelTensorFunctor
(
f_mk_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
MK_NK_NM
)
{
auto
f_mk_nk_nm
=
[
&
](
auto
n
,
auto
m
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
1
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
m
,
k
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
n
,
m
)
=
v
;
};
make_ParallelTensorFunctor
(
f_mk_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_KN_NM
)
{
auto
f_km_kn_nm
=
[
&
](
auto
n
,
auto
m
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
0
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
k
,
m
))
*
static_cast
<
const
double
>
(
b
(
k
,
n
));
}
c
(
n
,
m
)
=
v
;
};
make_ParallelTensorFunctor
(
f_km_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_NK_NM
)
{
auto
f_km_nk_nm
=
[
&
](
auto
n
,
auto
m
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
0
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
k
,
m
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
n
,
m
)
=
v
;
};
make_ParallelTensorFunctor
(
f_km_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
{
throw
std
::
runtime_error
(
"wrong! not supported layout"
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
if
(
argc
!=
12
)
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: M, N, K
\n
"
);
printf
(
"debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01
\n
"
);
exit
(
1
);
}
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
auto
algo
=
static_cast
<
GemmAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
M
=
std
::
stoi
(
argv
[
7
]);
const
index_t
N
=
std
::
stoi
(
argv
[
8
]);
const
index_t
K
=
std
::
stoi
(
argv
[
9
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
=
std
::
stoi
(
argv
[
10
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
=
std
::
stoi
(
argv
[
11
]);
#if 0
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif
1
using
ab_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
c_data_t
=
half_t
;
#elif 1
using
ab_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
c_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
a_lengths_host
(
2
),
b_lengths_host
(
2
),
c_lengths_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_strides_host
(
2
),
b_strides_host
(
2
),
c_strides_host
(
2
);
// A
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_KN_NM
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
// B
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
||
layout
==
GemmMatrixLayout
::
KM_NK_NM
)
{
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
// C
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
KM_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
Tensor
<
ab_data_t
>
a
(
a_lengths_host
,
a_strides_host
);
Tensor
<
ab_data_t
>
b
(
b_lengths_host
,
b_strides_host
);
Tensor
<
c_data_t
>
c_host
(
c_lengths_host
,
c_strides_host
);
Tensor
<
c_data_t
>
c_device
(
c_lengths_host
,
c_strides_host
);
std
::
cout
<<
"layout: "
<<
layout
<<
std
::
endl
;
ostream_HostTensorDescriptor
(
a
.
mDesc
,
std
::
cout
<<
"a: "
);
ostream_HostTensorDescriptor
(
b
.
mDesc
,
std
::
cout
<<
"b: "
);
ostream_HostTensorDescriptor
(
c_host
.
mDesc
,
std
::
cout
<<
"c: "
);
std
::
size_t
num_thread
=
1
;
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
<
ab_data_t
>
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
<
ab_data_t
>
{},
num_thread
);
break
;
case
2
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
<
ab_data_t
>
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
<
ab_data_t
>
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
<
ab_data_t
>
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
<
ab_data_t
>
{},
num_thread
);
break
;
case
4
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
<
ab_data_t
>
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
<
ab_data_t
>
{
-
5
,
5
},
num_thread
);
break
;
default:
a
.
GenerateTensorValue
(
GeneratorTensor_3
<
ab_data_t
>
{
0.0
,
1.0
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
ab_data_t
>
{
-
0.5
,
0.5
},
num_thread
);
}
#if USE_GEMM_XDL_MK_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_KN_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_KM_KN_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_KN_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_gemm
(
a
,
b
,
c_host
,
layout
);
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_host
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_device
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
7a3b49e5
include_directories
(
BEFORE
${
PROJECT_SOURCE_DIR
}
/include/ck
${
PROJECT_SOURCE_DIR
}
/include/ck/utility
${
PROJECT_SOURCE_DIR
}
/include/ck/host_utility
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_description
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor
${
PROJECT_SOURCE_DIR
}
/include/ck/problem_transform
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/device
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/grid
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/block
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/warp
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/thread
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_operation/gpu/element
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/host_tensor
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/host
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance
${
PROJECT_SOURCE_DIR
}
/library/include/ck/library/tensor_operation_instance/gpu/reduce
${
PROJECT_SOURCE_DIR
}
/external/include/half
)
function
(
add_instance_library INSTANCE_NAME
)
message
(
"adding instance
${
INSTANCE_NAME
}
"
)
add_library
(
${
INSTANCE_NAME
}
OBJECT
${
ARGN
}
)
...
...
@@ -30,19 +10,20 @@ add_subdirectory(gemm_bias2d)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu_add
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
conv1d_fwd
)
add_subdirectory
(
conv2d_fwd
)
add_subdirectory
(
conv3d_fwd
)
add_subdirectory
(
conv2d_fwd_bias_relu
)
add_subdirectory
(
conv2d_fwd_bias_relu_add
)
add_subdirectory
(
conv2d_fwd_bias_relu_atomic_add
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
reduce
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
conv2d_bwd_weight
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_library
(
device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
...
...
@@ -51,7 +32,6 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
$<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
...
...
@@ -62,7 +42,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
device_conv2d.cpp
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
)
add_library
(
composablekernels::device_operations ALIAS device_operations
)
...
...
@@ -70,8 +50,8 @@ add_library(composablekernels::device_operations ALIAS device_operations)
set
(
DEV_OPS_INC_DIRS
${
PROJECT_SOURCE_DIR
}
/include/ck/
${
PROJECT_SOURCE_DIR
}
/library/include/ck/
${
PROJECT_SOURCE_DIR
}
/external/include/
)
target_compile_features
(
device_operations PUBLIC
)
set_target_properties
(
device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
target_include_directories
(
device_operations PUBLIC
...
...
@@ -90,15 +70,16 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/host>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/reduce>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/half>
)
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options
(
device_operations
PRIVATE --offload-arch=gfx908
target_compile_options
(
device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
)
# install(TARGETS device_operations LIBRARY DESTINATION lib)
install
(
TARGETS device_operations
EXPORT device_operationsTargets
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -21,11 +26,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -62,12 +67,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
View file @
7a3b49e5
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -21,11 +26,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -62,12 +67,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment