Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
88b978c5
Commit
88b978c5
authored
Jun 03, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
e4112de7
6fb1f4e0
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1283 additions
and
285 deletions
+1283
-285
include/ck_tile/host/stream_config.hpp
include/ck_tile/host/stream_config.hpp
+17
-0
include/ck_tile/host/timer.hpp
include/ck_tile/host/timer.hpp
+79
-0
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
+3
-3
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
+55
-4
library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt
...nsor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt
+6
-1
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+1
-107
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+59
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...i_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+1
-106
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...i_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
pyproject.toml
pyproject.toml
+36
-0
python/ck4inductor/__init__.py
python/ck4inductor/__init__.py
+0
-0
python/ck4inductor/universal_gemm/gen_instances.py
python/ck4inductor/universal_gemm/gen_instances.py
+570
-0
python/ck4inductor/universal_gemm/op.py
python/ck4inductor/universal_gemm/op.py
+95
-0
python/ck4inductor/util.py
python/ck4inductor/util.py
+7
-0
test/position_embedding/position_embedding.cpp
test/position_embedding/position_embedding.cpp
+62
-62
No files found.
include/ck_tile/host/stream_config.hpp
View file @
88b978c5
...
@@ -6,6 +6,22 @@
...
@@ -6,6 +6,22 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace
ck_tile
{
namespace
ck_tile
{
/*
* construct this structure with behavior as:
*
* // create stream config with default stream(NULL), and not timing the kernel
* stream_config s = stream_config{};
*
* // create stream config with _some_stream_id_, and not timing the kernel
* stream_config s = stream_config{_some_stream_id_};
*
* // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
* stream_config s = stream_config{_some_stream_id_, true};
*
* // create stream config with _some_stream_id_, and benchmark using cpu timer
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
**/
struct
stream_config
struct
stream_config
{
{
hipStream_t
stream_id_
=
nullptr
;
hipStream_t
stream_id_
=
nullptr
;
...
@@ -13,5 +29,6 @@ struct stream_config
...
@@ -13,5 +29,6 @@ struct stream_config
int
log_level_
=
0
;
int
log_level_
=
0
;
int
cold_niters_
=
3
;
int
cold_niters_
=
3
;
int
nrepeat_
=
10
;
int
nrepeat_
=
10
;
bool
is_gpu_timer_
=
true
;
// keep compatible
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/host/timer.hpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace
ck_tile
{
struct
gpu_timer
{
CK_TILE_HOST
gpu_timer
()
{
HIP_CHECK_ERROR
(
hipEventCreate
(
&
start_evt
));
HIP_CHECK_ERROR
(
hipEventCreate
(
&
stop_evt
));
}
CK_TILE_HOST
~
gpu_timer
()
noexcept
(
false
)
{
HIP_CHECK_ERROR
(
hipEventDestroy
(
start_evt
));
HIP_CHECK_ERROR
(
hipEventDestroy
(
stop_evt
));
}
CK_TILE_HOST
void
start
(
const
hipStream_t
&
s
)
{
HIP_CHECK_ERROR
(
hipDeviceSynchronize
());
HIP_CHECK_ERROR
(
hipEventRecord
(
start_evt
,
s
));
}
CK_TILE_HOST
void
stop
(
const
hipStream_t
&
s
)
{
HIP_CHECK_ERROR
(
hipEventRecord
(
stop_evt
,
s
));
HIP_CHECK_ERROR
(
hipEventSynchronize
(
stop_evt
));
}
// return in ms
CK_TILE_HOST
float
duration
()
const
{
float
ms
=
0
;
HIP_CHECK_ERROR
(
hipEventElapsedTime
(
&
ms
,
start_evt
,
stop_evt
));
return
ms
;
}
private:
hipEvent_t
start_evt
,
stop_evt
;
};
struct
cpu_timer
{
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST
void
start
(
const
hipStream_t
&
)
{
HIP_CHECK_ERROR
(
hipDeviceSynchronize
());
start_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
}
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST
void
stop
(
const
hipStream_t
&
)
{
HIP_CHECK_ERROR
(
hipDeviceSynchronize
());
stop_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
}
// return in ms
CK_TILE_HOST
float
duration
()
const
{
double
sec
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
duration
<
double
>>
(
stop_tick
-
start_tick
)
.
count
();
return
static_cast
<
float
>
(
sec
*
1e3
);
}
private:
std
::
chrono
::
time_point
<
std
::
chrono
::
high_resolution_clock
>
start_tick
;
std
::
chrono
::
time_point
<
std
::
chrono
::
high_resolution_clock
>
stop_tick
;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
View file @
88b978c5
...
@@ -23,13 +23,13 @@ VERTICAL:
...
@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT:
TOP_LEFT
(but negative)
:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
1 [0] 1 2 3 4
1 [0] 1 2 3 4
2 1 [0] 1 2 3
2 1 [0] 1 2 3
3 2 1 [0] 1 2
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
FROM_BOTTOM_RIGHT
(but negative)
:
2 1 [0] 1 2 3
2 1 [0] 1 2 3
3 2 1 [0] 1 2
3 2 1 [0] 1 2
4 3 2 1 [0] 1
4 3 2 1 [0] 1
...
@@ -54,7 +54,7 @@ struct Alibi
...
@@ -54,7 +54,7 @@ struct Alibi
index_t
x_total_
,
index_t
x_total_
,
AlibiMode
mode_
=
AlibiMode
::
VERTICAL
)
AlibiMode
mode_
=
AlibiMode
::
VERTICAL
)
{
{
slope
=
mode_
==
AlibiMode
::
VERTICAL
?
slope_
:
-
slope
;
slope
=
mode_
==
AlibiMode
::
VERTICAL
?
slope_
:
-
slope
_
;
shift_left_up
=
[
&
]()
{
shift_left_up
=
[
&
]()
{
if
(
RowMajor
)
if
(
RowMajor
)
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
88b978c5
...
@@ -76,7 +76,7 @@ struct FmhaFwdKernel
...
@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
...
@@ -702,7 +702,7 @@ struct FmhaFwdKernel
...
@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else
else
{
{
return
Alibi
<
SaccDataType
,
true
>
{
return
Alibi
<
SaccDataType
,
true
>
{
slope
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
AlibiMode
::
VERTICAL
};
slope
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
AlibiMode
::
FROM_BOTTOM_RIGHT
};
}
}
}
}
else
else
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
View file @
88b978c5
...
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
...
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
static
constexpr
const
char
*
name
=
"shb"
;
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
hdim_v_
)
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
{
// TODO: this may need tuning
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
...
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
...
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
}
}
};
};
template
<
typename
BlockFmhaShape_
>
using
FmhaFwdTilePartitioner_SHB
=
FmhaFwdTilePartitioner
<
BlockFmhaShape_
>
;
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdTilePartitioner_HBS
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
const
char
*
name
=
"hbs"
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
// TODO: this may need tuning
return
dim3
(
nhead_
,
batch_size_
,
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
kN1
));
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
z
;
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
}
// namespace ck_tile
library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt
View file @
88b978c5
...
@@ -2,9 +2,14 @@
...
@@ -2,9 +2,14 @@
set
(
GEMM_MULTI_ABD_INSTANCES
)
set
(
GEMM_MULTI_ABD_INSTANCES
)
list
(
APPEND GEMM_MULTI_ABD_INSTANCES
list
(
APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
)
)
...
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
Multiply
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
PassThrough
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
PassThrough
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
EDataType
,
AElementOp
,
Multiply
,
Add
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
Multiply
,
Add
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
Multiply
,
Add
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
View file @
88b978c5
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <cstdlib>
...
@@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
...
@@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
Interwave
>
{});
Interwave
>
{});
}
}
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
EDataType
,
AElementOp
,
Multiply
,
Add
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
Multiply
,
Add
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<
D0Layout
>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<
D0DataType
>
,
Multiply
,
Add
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
Multiply
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
PassThrough
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
PassThrough
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
Multiply
,
FastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
FastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
FastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
EDataType
,
AElementOp
,
Multiply
,
FastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
FastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
,
B1Layout
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
B0DataType
,
B1DataType
>
,
ck
::
Tuple
<>
,
Multiply
,
FastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
Multiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
View file @
88b978c5
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <cstdlib>
...
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
...
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
Interwave
>
{});
Interwave
>
{});
}
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
Multiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
88b978c5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
pyproject.toml
0 → 100644
View file @
88b978c5
[build-system]
requires
=
[
"setuptools"
,
"setuptools-scm"
]
build-backend
=
"setuptools.build_meta"
[project]
name
=
"rocm-composable-kernel"
dynamic
=
["version"]
description
=
"Composable Kernel, performance-critical kernels for machine learning workloads"
readme
=
"README.md"
requires-python
=
">=3.8"
license
=
{
file
=
"LICENSE"
}
classifiers
=
[
"Programming Language :: Python :: 3"
,
"License :: OSI Approved :: MIT License"
,
"Operating System :: OS Independent"
,
]
dependencies
=
[]
[project.urls]
"Homepage"
=
"https://github.com/rocm/composable_kernel"
"Bug
Tracker"
=
"https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages
=
[
"ck4inductor"
,
"ck4inductor.include"
,
"ck4inductor.library"
]
[tool.setuptools.package-dir]
ck4inductor
=
"python/ck4inductor"
"ck4inductor.include"
=
"include"
"ck4inductor.library"
=
"library"
[tool.setuptools.package-data]
"ck4inductor.include"
=
["ck/**/*.hpp"]
"ck4inductor.library"
=
["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
[tool.setuptools.dynamic]
version
=
{
attr
=
"setuptools_scm.get_version"
}
python/ck4inductor/__init__.py
0 → 100644
View file @
88b978c5
python/ck4inductor/universal_gemm/gen_instances.py
0 → 100644
View file @
88b978c5
import
logging
import
os
import
subprocess
from
dataclasses
import
fields
,
replace
from
functools
import
lru_cache
,
partial
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKGemmOperation
log
=
logging
.
getLogger
(
__name__
)
def _ck_library_dir():
    """Locate the CK universal-gemm instance sources; return None (and log) if absent."""
    instances_dir = os.path.join(
        library_path(),
        "src",
        "tensor_operation_instance",
        "gpu",
        "gemm_universal",
    )
    if os.path.exists(instances_dir):
        return instances_dir
    log.error("CK library path %s does not exist", instances_dir)
    return None
def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
    """
    Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances
    """

    def maybe_int(s):
        # Convert numeric template arguments to int; leave symbolic ones as str.
        try:
            return int(s)
        except ValueError:
            return s

    op_instances = []
    for line in str_instances:
        # Everything after the template name is the `<...>` argument list;
        # strip the angle brackets and surrounding separators.
        s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ")
        template_args = []
        i_current = 0
        # Hand-rolled scanner over the comma-separated argument list.
        # NOTE(review): assumes `S<...>` sequences are not nested — the first
        # `>` found is taken to close the sequence.
        while i_current < len(s_template_args):
            if s_template_args[i_current] == " ":
                # skip whitespace
                i_current += 1
                continue
            elif s_template_args[i_current : i_current + 2] == "S<":
                # parse template S<Index...> into a tuple of ints
                i_next = s_template_args.find(">", i_current)
                template_args.append(
                    tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
                )
                # skip past "...>," (the closing bracket and trailing comma)
                i_current = i_next + 2
            else:
                # all string attributes must be either type aliases or global constants in C++
                i_next = s_template_args.find(",", i_current)
                template_args.append(
                    maybe_int(s_template_args[i_current : i_next if i_next != -1 else None])
                )
                if i_next != -1:
                    i_current = i_next + 1
            if i_next == -1:
                # no further separator: the last argument has been consumed
                break
        # pad with `None`s for the fields which are not defined in the instance
        new_instance = CKGemmOperation(
            *template_args,  # type: ignore[arg-type]
            *((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
        )
        # the last 2 template parameters are optional
        # if they are absent, substitute them with default values from Universal Gemm C++ template declaration
        if new_instance.a_compute_dtype is None:
            new_instance.a_compute_dtype = new_instance.c_element_dtype
        if new_instance.b_compute_dtype is None:
            new_instance.b_compute_dtype = new_instance.c_element_dtype
        op_instances.append(new_instance)
    return op_instances
def default_instances() -> List[CKGemmOperation]:
    """Hard-coded fallback instance list used when the CK library sources are unavailable."""
    # fallback: known working op instance for problem size M=2240 K=256 N=2048
    # all string attributes must be either type aliases or global constants in C++
    return [
        CKGemmOperation(
            a_layout="Row",
            b_layout="Row",
            c_layout="Row",
            a_element_dtype="F16",
            b_element_dtype="F16",
            c_element_dtype="F16",
            a_compute_dtype="F16",
            b_compute_dtype="F16",
            acc_dtype="F32",
            c_shuffle_dtype="F16",
            a_elementwise_op="PassThrough",
            b_elementwise_op="PassThrough",
            c_elementwise_op="PassThrough",
            gemm_specialization="GemmSpecialization::Default",
            block_size=256,
            m_per_block=224,
            n_per_block=256,
            k_per_block=64,
            a_k1=8,
            b_k1=2,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=7,
            n_xdl_per_wave=8,
            a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
            a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
            a_block_transfer_src_access_order=(1, 0, 2),
            a_block_transfer_src_vector_dim=2,
            a_block_transfer_src_scalar_per_vector=8,
            a_block_transfer_dst_scalar_per_vector_ak1=8,
            a_block_lds_extra_m=0,  # type: ignore[arg-type]
            b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
            b_block_transfer_thread_cluster_arrange_order=(0, 2, 1),
            b_block_transfer_src_access_order=(0, 2, 1),
            b_block_transfer_src_vector_dim=1,
            b_block_transfer_src_scalar_per_vector=8,
            b_block_transfer_dst_scalar_per_vector_bk1=2,
            b_block_lds_extra_n=0,  # type: ignore[arg-type]
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        )
    ]
@lru_cache(None)
def gen_ops_library() -> List[CKGemmOperation]:
    """
    Parse the Universal Gemm instances defined in the composable kernel library folder.

    Returns an empty list when the CK library sources cannot be located.
    The result is memoized with `lru_cache`.
    """
    ck_library_dir = _ck_library_dir()
    if not ck_library_dir:
        return []
    # Fix: reuse the directory resolved above. The original called
    # _ck_library_dir() a second time inside the grep argv, redoing the
    # path lookup/existence check (and its error-logging path).
    grep_result = subprocess.run(
        [
            "grep",
            "-inR",
            "DeviceGemm_Xdl_CShuffleV3",
            ck_library_dir,
        ],
        capture_output=True,
        text=True,
    )
    op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
    log.debug("ck instances from library: %d", len(op_instances))

    # Placeholder identifiers used by the C++ sources for templated arguments;
    # instances carrying them are expanded over the full domains below.
    schedulers = [
        "BlockGemmPipelineScheduler::Intrawave",
        "BlockGemmPipelineScheduler::Interwave",
    ]
    gemm_specs = [
        "GemmSpecialization::Default",
        "GemmSpecialization::MPadding",
        "GemmSpecialization::NPadding",
        "GemmSpecialization::KPadding",
        "GemmSpecialization::MNPadding",
        "GemmSpecialization::MKPadding",
        "GemmSpecialization::NKPadding",
        "GemmSpecialization::MNKPadding",
    ]
    # substitute templated args by looping through their domains
    substitute_instances = []
    for instance in op_instances:
        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
        sub_spec = instance.gemm_specialization == "GemmSpec"
        schedulers_range = (
            schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
        )
        spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
        for scheduler in schedulers_range:
            for spec in spec_range:
                substitute_instances.append(
                    replace(
                        instance,
                        block_gemm_pipeline_scheduler=scheduler,
                        gemm_specialization=spec,
                    )
                )
    return substitute_instances
@lru_cache(None)
def gen_ops_preselected() -> List[CKGemmOperation]:
    """
    Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances
    """
    # Base template shared by every preselected instance; the three partials
    # below specialize it for compute-, memory-, and latency-bound regimes.
    ck_gemm_f16_rcr = partial(
        CKGemmOperation,
        a_layout="Row",
        b_layout="Col",
        c_layout="Row",
        a_element_dtype="F16",
        b_element_dtype="F16",
        c_element_dtype="F16",
        acc_dtype="F32",
        c_shuffle_dtype="F16",
        a_elementwise_op="PassThrough",
        b_elementwise_op="PassThrough",
        c_elementwise_op="PassThrough",
        k_per_block=64,
        a_k1=8,
        b_k1=8,
        a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        a_block_transfer_src_access_order=(1, 0, 2),
        a_block_transfer_src_vector_dim=2,
        a_block_transfer_src_scalar_per_vector=8,
        a_block_transfer_dst_scalar_per_vector_ak1=8,
        a_block_lds_extra_m=0,
        b_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        b_block_transfer_src_access_order=(1, 0, 2),
        b_block_transfer_src_vector_dim=2,
        b_block_transfer_src_scalar_per_vector=8,
        b_block_transfer_dst_scalar_per_vector_bk1=8,
        b_block_lds_extra_n=0,
        a_compute_dtype="F16",
        b_compute_dtype="F16",
    )
    # Large tiles / 256 threads: best for compute-bound shapes.
    ck_gemm_f16_rcr_compute_friendly = partial(
        ck_gemm_f16_rcr,
        block_size=256,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
        c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
            1,
            32,
            1,
            8,
        ),
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
    )
    # Small tiles / interwave v2 pipeline: best for memory-bound shapes.
    ck_gemm_f16_rcr_memory_friendly = partial(
        ck_gemm_f16_rcr,
        block_size=128,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave",
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v2",
    )
    # Minimal waves / v1 pipeline: best when kernel launch latency dominates.
    ck_gemm_f16_rcr_latency_friendly = partial(
        ck_gemm_f16_rcr,
        gemm_specialization="GemmSpecialization::Default",
        block_size=128,
        m_per_xdl=16,
        n_per_xdl=16,
        m_xdl_per_wave=1,
        n_xdl_per_wave=1,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        c_shuffle_m_xdl_per_wave_per_shuffle=1,
        c_shuffle_n_xdl_per_wave_per_shuffle=1,
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v1",
    )
    return [
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=224,
            n_per_block=256,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=7,
            n_xdl_per_wave=8,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        # 128x128 compute-friendly tiles, padded shapes, pipeline v3/v4/v5.
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
        ),
        # Same tiles without padding (exact-divisible problem sizes).
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
        ),
        # Memory-friendly variants covering skinny/small tile shapes.
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=16,
            n_per_block=32,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=16,
            n_per_block=32,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=16,
            n_per_block=64,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=64,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=32,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=64,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=2,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=2,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                64,
                1,
                2,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=64,
            n_per_block=32,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=32,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=2,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        # Latency-friendly variants for very small problems.
        ck_gemm_f16_rcr_latency_friendly(
            m_per_block=16,
            n_per_block=32,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
        ),
        ck_gemm_f16_rcr_latency_friendly(
            m_per_block=32,
            n_per_block=16,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
        ),
    ]
if __name__ == "__main__":
    # CLI smoke test: print every instance parsed from the CK library sources.
    print(gen_ops_library())
python/ck4inductor/universal_gemm/op.py
0 → 100644
View file @
88b978c5
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@dataclass
class CKGemmOperation:
    """
    A python dataclass storing the template parameters of a CK Universal Gemm template instance
    """

    # Layouts of the A/B/C matrices (e.g. "Row" / "Col").
    a_layout: str
    b_layout: str
    c_layout: str
    # Element, accumulator and shuffle datatypes (C++ type aliases, e.g. "F16").
    a_element_dtype: str
    b_element_dtype: str
    c_element_dtype: str
    acc_dtype: str
    c_shuffle_dtype: str
    # Elementwise epilogue operators (C++ type aliases).
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str
    # Padding specialization, e.g. "GemmSpecialization::MNKPadding".
    gemm_specialization: str
    # Tiling parameters.
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    a_k1: int
    b_k1: int
    m_per_xdl: int
    n_per_xdl: int
    m_xdl_per_wave: int
    n_xdl_per_wave: int
    # A-matrix block-transfer parameters.
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool
    # B-matrix block-transfer parameters.
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]
    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool
    # C-shuffle epilogue parameters.
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int
    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
        Tuple[int, int, int, int]
    )
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
    # Pipeline selection; the last three may be absent from a parsed instance.
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: Optional[str]
    a_compute_dtype: Optional[str]
    b_compute_dtype: Optional[str]

    def name(self):
        """C++ alias for this template instance."""
        return "ck_devicegemm_xdl_shuffle_v3_" + self.key_name()

    def key_name(self):
        """Identifier unique per instance; intended for use as a dict key."""
        parts = []
        for field_name, field_value in self.dict_items():
            if isinstance(field_value, tuple):
                rendered = "x".join(str(item) for item in field_value)
            else:
                rendered = str(field_value).replace(":", "")
            parts.append("K" + field_name.replace("_", "").lower() + "V" + rendered)
        return "_".join(parts)

    def dict_items(self):
        """(field_name, field_value) pairs of this dataclass, in declaration order."""
        return asdict(self).items()
python/ck4inductor/util.py
0 → 100644
View file @
88b978c5
import
functools
import
os
@functools.lru_cache(None)
def library_path():
    """Absolute path of the bundled CK `library` directory (sibling of this module)."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'library')
test/position_embedding/position_embedding.cpp
View file @
88b978c5
...
@@ -131,74 +131,74 @@ int main()
...
@@ -131,74 +131,74 @@ int main()
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
1
,
0
,
1
,
2
,
3
,
4
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
2
,
1
,
0
,
1
,
2
,
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
});
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
,
-
3
,
-
2
,
-
1
,
0
,
4
,
3
,
2
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
5
,
4
,
3
,
2
});
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
4
,
3
,
2
,
1
,
0
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
5
,
4
,
3
,
2
,
1
,
0
});
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
1
,
2
,
3
,
4
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
1
,
2
,
3
,
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
});
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
VERTICAL
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
VERTICAL
,
{
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
1
,
0
,
1
,
2
,
3
,
4
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
2
,
1
,
0
,
1
,
2
,
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
});
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
,
-
3
,
-
2
,
-
1
,
0
,
4
,
3
,
2
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
5
,
4
,
3
,
2
});
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
4
,
3
,
2
,
1
,
0
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
5
,
4
,
3
,
2
,
1
,
0
});
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
1
,
2
,
3
,
4
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
1
,
2
,
3
,
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
});
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
8
,
{
0.5
,
0.25
,
0.125
,
0.0625
,
0.03125
,
0.015625
,
0.0078125
,
0.00390625
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
8
,
{
0.5
,
0.25
,
0.125
,
0.0625
,
0.03125
,
0.015625
,
0.0078125
,
0.00390625
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
16
,
{
0.7071067811865476
,
0.5
,
0.35355339059327384
,
0.25000000000000006
,
0.17677669529663692
,
rtn
&=
test_alibi_slope_generation
<
float
>
(
16
,
{
0.7071067811865476
,
0.5
,
0.35355339059327384
,
0.25000000000000006
,
0.17677669529663692
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment