Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9ef5ba27
You need to sign in or sign up before continuing.
Unverified
Commit
9ef5ba27
authored
Dec 02, 2024
by
rocking
Committed by
GitHub
Dec 02, 2024
Browse files
Merge branch 'develop' into ck_tile/pure_quant
parents
e9576baa
44828b7c
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1127 additions
and
10 deletions
+1127
-10
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/ck_tile/16_batched_gemm/CMakeLists.txt
example/ck_tile/16_batched_gemm/CMakeLists.txt
+1
-0
example/ck_tile/16_batched_gemm/README.md
example/ck_tile/16_batched_gemm/README.md
+37
-0
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+103
-0
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+63
-0
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+253
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-1
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+112
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+3
-3
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+258
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-1
python/ck4inductor/batched_universal_gemm/gen_instances.py
python/ck4inductor/batched_universal_gemm/gen_instances.py
+149
-0
python/ck4inductor/batched_universal_gemm/op.py
python/ck4inductor/batched_universal_gemm/op.py
+99
-0
python/ck4inductor/grouped_conv_fwd/gen_instances.py
python/ck4inductor/grouped_conv_fwd/gen_instances.py
+1
-3
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
test/ck_tile/batched_gemm/CMakeLists.txt
test/ck_tile/batched_gemm/CMakeLists.txt
+4
-0
test/ck_tile/batched_gemm/test_batched_gemm.cpp
test/ck_tile/batched_gemm/test_batched_gemm.cpp
+29
-0
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
+9
-0
No files found.
docs/sphinx/requirements.in
View file @
9ef5ba27
rocm-docs-core==1.9.
1
rocm-docs-core==1.9.
2
sphinxcontrib-bibtex==2.6.3
sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt
View file @
9ef5ba27
...
@@ -103,7 +103,7 @@ requests==2.32.3
...
@@ -103,7 +103,7 @@ requests==2.32.3
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==1.9.
1
rocm-docs-core==1.9.
2
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via pybtex
# via pybtex
...
...
example/ck_tile/16_batched_gemm/CMakeLists.txt
0 → 100644
View file @
9ef5ba27
add_executable
(
tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp
)
example/ck_tile/16_batched_gemm/README.md
0 → 100644
View file @
9ef5ba27
# Batched GEMM
This folder contains example for batched GEMM using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable
`build/bin/tile_example_batched_gemm`
## example
```
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
-stride_a Tensor A stride (default:128)
-stride_b Tensor B stride (default:128)
-stride_c Tensor C stride (default:128)
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
\ No newline at end of file
example/ck_tile/16_batched_gemm/batched_gemm.cpp
0 → 100644
View file @
9ef5ba27
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
#include "run_batched_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_batched_gemm_example
(
argc
,
argv
);
}
example/ck_tile/16_batched_gemm/batched_gemm.hpp
0 → 100644
View file @
9ef5ba27
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template
<
typename
DataType
>
struct
BatchedGemmTypeConfig
;
template
<
>
struct
BatchedGemmTypeConfig
<
ck_tile
::
half_t
>
{
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
using
Types
=
BatchedGemmTypeConfig
<
ck_tile
::
half_t
>
;
// Specific type aliases for easy access
using
ADataType
=
Types
::
ADataType
;
using
BDataType
=
Types
::
BDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"256"
,
"m dimension"
)
.
insert
(
"n"
,
"128"
,
"n dimension"
)
.
insert
(
"k"
,
"128"
,
"k dimension"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"batch_stride_a"
,
"32768"
,
"Batch A stride"
)
.
insert
(
"batch_stride_b"
,
"16384"
,
"Batch B stride"
)
.
insert
(
"batch_stride_c"
,
"32768"
,
"Batch C stride"
)
.
insert
(
"batch_count"
,
"16"
,
"Batch count"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
// host API
float
batched_gemm
(
batched_gemm_kargs
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
0 → 100644
View file @
9ef5ba27
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_batched_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
K
,
ck_tile
::
index_t
stride_A
,
ck_tile
::
index_t
stride_B
,
ck_tile
::
index_t
stride_C
,
ck_tile
::
index_t
batch_stride_A
,
ck_tile
::
index_t
batch_stride_B
,
ck_tile
::
index_t
batch_stride_C
,
ck_tile
::
index_t
batch_count
,
int
n_warmup
,
int
n_repeat
)
{
batched_gemm_kargs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
stride_A
=
stride_A
;
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
args
.
batch_stride_A
=
batch_stride_A
;
args
.
batch_stride_B
=
batch_stride_B
;
args
.
batch_stride_C
=
batch_stride_C
;
args
.
batch_count
=
batch_count
;
float
ave_time
=
batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
string
op_name
{
"Batched Gemm"
};
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
batch_count
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
sizeof
(
ADataType
)
*
batch_count
*
M
*
K
+
sizeof
(
BDataType
)
*
batch_count
*
N
*
K
+
sizeof
(
CDataType
)
*
batch_count
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Run "
<<
op_name
<<
"kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" batch_stride_A ="
<<
batch_stride_A
<<
" batch_stride_B ="
<<
batch_stride_B
<<
" batch_stride_C ="
<<
batch_stride_C
<<
" batch_count ="
<<
batch_count
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_batched_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
const
BLayout
b_layout
=
BLayout
{},
[[
maybe_unused
]]
const
CLayout
c_layout
=
CLayout
{})
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
stride_A
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
batch_stride_A
=
arg_parser
.
get_int
(
"batch_stride_a"
);
ck_tile
::
index_t
batch_stride_B
=
arg_parser
.
get_int
(
"batch_stride_b"
);
ck_tile
::
index_t
batch_stride_C
=
arg_parser
.
get_int
(
"batch_stride_c"
);
ck_tile
::
index_t
batch_count
=
arg_parser
.
get_int
(
"batch_count"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
c_layout
);
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
c_layout
));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
,
n_warmup
,
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
CLayout
{}));
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_gpu_buf_ref
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
,
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
);
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
}
int
run_batched_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_batched_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_batched_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work else if(a_layout == "C" && b_layout == "C")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
example/ck_tile/CMakeLists.txt
View file @
9ef5ba27
...
@@ -15,4 +15,4 @@ add_subdirectory(12_smoothquant)
...
@@ -15,4 +15,4 @@ add_subdirectory(12_smoothquant)
add_subdirectory
(
13_moe_sorting
)
add_subdirectory
(
13_moe_sorting
)
add_subdirectory
(
14_moe_smoothquant
)
add_subdirectory
(
14_moe_smoothquant
)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
16_batched_gemm
)
include/ck_tile/host/reference/reference_gemm.hpp
View file @
9ef5ba27
...
@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
return
;
return
;
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_a
,
index_t
stride_b
,
index_t
stride_c
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_C
,
index_t
batch_count
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
9ef5ba27
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
View file @
9ef5ba27
...
@@ -623,7 +623,7 @@ struct BlockUniversalGemmAsBsCr
...
@@ -623,7 +623,7 @@ struct BlockUniversalGemmAsBsCr
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
CK_TILE_DEVICE
void
LocalPrefetch
(
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
const
BSmemBlockWindow
&
b_block_window
)
{
{
block_gemm_impl_
.
template
LocalPrefetch
(
a_block_window
,
b_block_window
);
block_gemm_impl_
.
LocalPrefetch
(
a_block_window
,
b_block_window
);
}
}
// C += A * B
// C += A * B
...
@@ -632,7 +632,7 @@ struct BlockUniversalGemmAsBsCr
...
@@ -632,7 +632,7 @@ struct BlockUniversalGemmAsBsCr
const
ASmemBlockWindow
&
a_block_window
,
const
ASmemBlockWindow
&
a_block_window
,
const
BSmemBlockWindow
&
b_block_window
)
const
BSmemBlockWindow
&
b_block_window
)
{
{
block_gemm_impl_
.
template
operator
()
(
c_block_tensor
,
a_block_window
,
b_block_window
);
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
}
}
// C = A * B
// C = A * B
...
@@ -641,7 +641,7 @@ struct BlockUniversalGemmAsBsCr
...
@@ -641,7 +641,7 @@ struct BlockUniversalGemmAsBsCr
const
BSmemBlockWindow
&
b_block_window
)
const
BSmemBlockWindow
&
b_block_window
)
{
{
auto
c_block_tensor
=
MakeCBlockTile
();
auto
c_block_tensor
=
MakeCBlockTile
();
block_gemm_impl_
.
template
operator
()
(
c_block_tensor
,
a_block_window
,
b_block_window
);
block_gemm_impl_
(
c_block_tensor
,
a_block_window
,
b_block_window
);
return
c_block_tensor
;
return
c_block_tensor
;
}
}
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
0 → 100644
View file @
9ef5ba27
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
struct
BatchedGemmHostArgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
BatchedGemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmPipeline
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmPipeline
::
CLayout
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
struct
BatchedGemmKargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
using
Kargs
=
BatchedGemmKargs
;
using
Hargs
=
BatchedGemmHostArgs
;
__host__
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
return
TilePartitioner
::
GridSize
(
h
.
M
,
h
.
N
,
h
.
batch_count
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
CK_TILE_HOST
static
constexpr
BatchedGemmKargs
MakeKargs
(
const
Hargs
&
h
)
{
Kargs
k
;
k
.
a_ptr
=
h
.
a_ptr
;
k
.
b_ptr
=
h
.
b_ptr
;
k
.
c_ptr
=
h
.
c_ptr
;
k
.
M
=
h
.
M
;
k
.
N
=
h
.
N
;
k
.
K
=
h
.
K
;
k
.
stride_A
=
h
.
stride_A
;
k
.
stride_B
=
h
.
stride_B
;
k
.
stride_C
=
h
.
stride_C
;
k
.
batch_stride_A
=
h
.
batch_stride_A
;
k
.
batch_stride_B
=
h
.
batch_stride_B
;
k
.
batch_stride_C
=
h
.
batch_stride_C
;
k
.
batch_count
=
h
.
batch_count
;
return
k
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
const
auto
batch_offset_A
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_A
);
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
auto
batch_stride_B
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_B
);
const
auto
batch_offset_B
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_B
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
+
batch_offset_A
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
+
batch_offset_B
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
}
}();
auto
a_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
// clang-format on
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wokrgroup.
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
const
auto
batch_stride_C
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_C
);
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
+
batch_offset_C
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
c_block_window
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
9ef5ba27
...
@@ -124,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -124,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
// Block GEMM
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
// Acc register tile
// Acc register tile
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
...
...
python/ck4inductor/batched_universal_gemm/gen_instances.py
0 → 100644
View file @
9ef5ba27
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
os
import
subprocess
from
dataclasses
import
replace
from
functools
import
lru_cache
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKBatchedGemmOperation
log
=
logging
.
getLogger
(
__name__
)
def
_ck_library_dir
():
gemm_instances_path
=
os
.
path
.
join
(
library_path
(),
"src"
,
"tensor_operation_instance"
,
"gpu"
,
"gemm_universal_batched"
,
)
if
not
os
.
path
.
exists
(
gemm_instances_path
):
log
.
error
(
"CK library path %s does not exist"
,
gemm_instances_path
)
return
None
return
gemm_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKBatchedGemmOperation
]:
"""
Parse the lines containing Universal Gemm template instances into `CKBatchedGemmOperation` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
# ds layout and dtype are parsed as placeholder; reset value
template_args
[
2
]
=
tuple
()
# ds layout
template_args
[
6
]
=
tuple
()
# ds dtype
new_instance
=
CKBatchedGemmOperation
(
*
template_args
,
# type: ignore[arg-type]
)
op_instances
.
append
(
new_instance
)
return
op_instances
@
lru_cache
(
None
)
def
gen_ops_library
()
->
List
[
CKBatchedGemmOperation
]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir
=
_ck_library_dir
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3"
,
_ck_library_dir
(),
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
gemm_specs
=
[
"GemmSpecialization::Default"
,
"GemmSpecialization::MPadding"
,
"GemmSpecialization::NPadding"
,
"GemmSpecialization::KPadding"
,
"GemmSpecialization::MNPadding"
,
"GemmSpecialization::MKPadding"
,
"GemmSpecialization::NKPadding"
,
"GemmSpecialization::MNKPadding"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
sub_spec
=
instance
.
gemm_specialization
==
"GemmSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
gemm_specs
if
sub_spec
else
[
instance
.
gemm_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
gemm_specialization
=
spec
,
)
)
return
substitute_instances
if
__name__
==
"__main__"
:
print
(
gen_ops_library
())
python/ck4inductor/batched_universal_gemm/op.py
0 → 100644
View file @
9ef5ba27
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@
dataclass
class
CKBatchedGemmOperation
:
"""
A python dataclass storing the template parameters of a CK Universal Gemm template instance
"""
a_layout
:
str
b_layout
:
str
ds_layouts
:
Tuple
[
str
]
# addmm specific
c_layout
:
str
a_element_dtype
:
str
b_element_dtype
:
str
ds_element_dtypes
:
Tuple
[
str
]
# addmm specific
c_element_dtype
:
str
acc_dtype
:
str
c_shuffle_dtype
:
str
a_elementwise_op
:
str
b_elementwise_op
:
str
c_elementwise_op
:
str
gemm_specialization
:
str
block_size
:
int
m_per_block
:
int
n_per_block
:
int
k_per_block
:
int
a_k1
:
int
b_k1
:
int
m_per_xdl
:
int
n_per_xdl
:
int
m_xdl_per_wave
:
int
n_xdl_per_wave
:
int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_vector_dim
:
int
a_block_transfer_src_scalar_per_vector
:
int
a_block_transfer_dst_scalar_per_vector_ak1
:
int
a_block_lds_extra_m
:
bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_vector_dim
:
int
b_block_transfer_src_scalar_per_vector
:
int
b_block_transfer_dst_scalar_per_vector_bk1
:
int
b_block_lds_extra_n
:
bool
c_shuffle_m_xdl_per_wave_per_shuffle
:
int
c_shuffle_n_xdl_per_wave_per_shuffle
:
int
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
:
(
Tuple
[
int
,
int
,
int
,
int
]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block
:
Tuple
[
int
]
block_gemm_pipeline_scheduler
:
str
block_gemm_pipeline_version
:
str
a_compute_dtype
:
Optional
[
str
]
=
None
b_compute_dtype
:
Optional
[
str
]
=
None
def
name
(
self
):
# cpp alias for template instance
return
f
"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_
{
self
.
key_name
()
}
"
def
key_name
(
self
):
# TBD; must be unique per instance. Intended to use as dict key
return
"_"
.
join
(
[
"K"
+
field_name
.
replace
(
"_"
,
""
).
lower
()
+
"V"
+
(
"x"
.
join
(
map
(
str
,
iter
(
field_value
)))
if
isinstance
(
field_value
,
tuple
)
else
str
(
field_value
).
replace
(
":"
,
""
)
)
for
field_name
,
field_value
in
self
.
dict_items
()
]
)
def
dict_items
(
self
):
return
asdict
(
self
).
items
()
python/ck4inductor/grouped_conv_fwd/gen_instances.py
View file @
9ef5ba27
...
@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
...
@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
# substitute templated args by looping through their domains
# substitute templated args by looping through their domains
substitute_instances
=
[]
substitute_instances
=
[]
for
instance
in
op_instances
:
for
instance
in
op_instances
:
sub_scheduler
=
(
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
)
sub_spec
=
instance
.
conv_forward_specialization
==
"ConvSpec"
sub_spec
=
instance
.
conv_forward_specialization
==
"ConvSpec"
schedulers_range
=
(
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
...
...
test/ck_tile/CMakeLists.txt
View file @
9ef5ba27
add_subdirectory
(
image_to_column
)
add_subdirectory
(
image_to_column
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm
)
add_subdirectory
(
batched_gemm
)
test/ck_tile/batched_gemm/CMakeLists.txt
0 → 100644
View file @
9ef5ba27
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_batched_gemm test_batched_gemm.cpp
)
endif
()
test/ck_tile/batched_gemm/test_batched_gemm.cpp
0 → 100644
View file @
9ef5ba27
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_batched_gemm_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
>
,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileBatchedGemm
,
KernelTypes
);
#include "test_batched_gemm_ut_cases.inc"
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
0 → 100644
View file @
9ef5ba27
#pragma once
TYPED_TEST
(
TestCkTileBatchedGemm
,
Basic
)
{
constexpr
int
M
=
256
;
constexpr
int
N
=
128
;
constexpr
int
K
=
128
;
this
->
Run
(
M
,
N
,
K
);
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment