gaoqiong / composable_kernel_ROCM · Commits

Commit 8bd49370, authored Oct 01, 2024 by Adam Osewski
Parent: d3689b06

    Refactoring & Move Layout info to pipeline problem.
Showing 11 changed files with 469 additions and 841 deletions (+469 -841):

  example/ck_tile/03_gemm/gemm_basic.cpp                                           +97  -142
  example/ck_tile/03_gemm/gemm_basic.hpp                                           +2   -5
  example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp                              +253 -207
  include/ck_tile/core/utility/literals.hpp                                        +22  -0
  include/ck_tile/host/reference/reference_gemm.hpp                                +19  -30
  include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp                                  +42  -47
  include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem.hpp           +0   -394
  include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp                +23  -5
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_default_policy.hpp      +3   -3
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp   +5   -7
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp           +3   -1
example/ck_tile/03_gemm/gemm_basic.cpp (view file @ 8bd49370)

@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include "gemm_basic.hpp"
 #include <hip/hip_runtime.h>
 #include <cstring>
...
@@ -11,6 +10,11 @@
 #include <string>
 #include <tuple>
+
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_basic.hpp"

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
...
@@ -22,7 +26,6 @@ auto create_args(int argc, char* argv[])
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("e", "1e-5", "Absolute error tolerance")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "10", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
...
@@ -51,13 +54,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;

     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel =
-        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
+        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

     auto kargs = Kernel::MakeKargs(args.p_a,
                                    args.p_b,
                                    args.p_c,
-                                   args.epsilon,
                                    args.M,
                                    args.N,
                                    args.K,
...
@@ -96,7 +97,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
         return -1; // Or handle the error appropriately
     }

-    float epsilon = arg_parser.get_float("e");
     ck_tile::index_t batch_size = arg_parser.get_int("b");
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
...
@@ -107,69 +107,37 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
     ck_tile::index_t stride_c = arg_parser.get_int("stride_c");

     gemm_basic_args args;
     args.p_a = a_buf.GetDeviceBuffer();
     args.p_b = b_buf.GetDeviceBuffer();
     args.p_c = c_buf.GetDeviceBuffer();
-    args.epsilon = epsilon;
     args.kbatch = batch_size;
     args.M = M;
     args.N = N;
     args.K = K;

-    // Only set stride_M and stride_N if they are non-zero and not equal to K.
-    if(stride_a != 0)
-    {
-        args.stride_A = stride_a;
-    }
-    else
-    {
-        args.stride_A = [&]() {
-            if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_b != 0)
-    {
-        args.stride_B = stride_b;
-    }
-    else
-    {
-        args.stride_B = [&]() {
-            if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return N;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_c != 0)
-    {
-        args.stride_C = stride_c;
-    }
-    else
-    {
-        args.stride_C = [&]() {
-            if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return N;
-            }
-        }();
-    }
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    args.stride_A = f_get_default_stride(M, K, stride_a, LayoutA{});
+    args.stride_B = f_get_default_stride(K, N, stride_b, LayoutB{});
+    args.stride_C = f_get_default_stride(M, N, stride_c, LayoutC{});

     float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
         args, ck_tile::stream_config{nullptr, true});
...
@@ -197,30 +165,57 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");

-    // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
-    using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
-    using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
-
-    // host verify
-    std::vector<int> a_dimensions =
-        (std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, K}
-            : std::vector<int>{K, M};
-    std::vector<int> b_dimensions =
-        (std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            ? std::vector<int>{N, K}
-            : std::vector<int>{K, N};
-    std::vector<int> c_dimensions =
-        (std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, N}
-            : std::vector<int>{N, M};
-
-    ck_tile::HostTensor<ADataType> a_host(a_dimensions);
-    ck_tile::HostTensor<BDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    using namespace ck_tile::literals;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout),
+                                            ck_tile::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
+    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
+    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
+
+    ck_tile::HostTensor<ADataType> a_host(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_host(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_ref(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_dev(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

     ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
     ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
...
@@ -259,6 +254,9 @@ int main(int argc, char* argv[])
                                                           BDataType,
                                                           AccDataType,
                                                           CodegenGemmShape,
+                                                          ALayout,
+                                                          BLayout,
+                                                          CLayout,
                                                           kPadA,
                                                           kPadB,
                                                           kPadC>;
...
@@ -266,9 +264,9 @@ int main(int argc, char* argv[])
     using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
     invoke_gemm<ck_tile::half_t,
-                matrix_a_layout,
-                matrix_b_layout,
-                matrix_c_layout,
+                ALayout,
+                BLayout,
+                CLayout,
                 CodegenPipelineProblem,
                 CodegenGemmPipeline,
                 CodegenGemmShape>(a_buf, b_buf, c_buf, arg_parser);
...
@@ -280,17 +278,12 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 1)
     {
         // ToDo: Will Add the Element Op (bias) verification in the future.
-        ck_tile::reference_gemm<ADataType,
-                                BDataType,
-                                AccDataType,
-                                CDataType,
-                                matrix_a_layout,
-                                matrix_b_layout,
-                                matrix_c_layout>(a_host, b_host, c_host_ref);
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_host, b_host, c_host_ref);

         pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);
-        std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
+        std::cout << "The CPU verification result is:" << (pass_cpu ? "correct" : "fail")
                   << std::flush;
     }
...
@@ -298,57 +291,19 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 2)
     {
-        ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-        ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-        ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-        if(stride_a == 0)
-        {
-            if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_a = M;
-            }
-            else
-            {
-                stride_a = K;
-            }
-        }
-        if(stride_b == 0)
-        {
-            if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_b = N;
-            }
-            else
-            {
-                stride_b = K;
-            }
-        }
-        if(stride_c == 0)
-        {
-            if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_c = M;
-            }
-            else
-            {
-                stride_c = N;
-            }
-        }
-
-        ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
+        ck_tile::HostTensor<CDataType> c_host_gpu_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
         ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
+        c_gpu_buf.SetZero();

         ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
-            a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
+            a_buf, b_buf, c_gpu_buf, M, N, K, stride_A, stride_B, stride_C);

-        c_buf.FromDevice(c_host_gpu_ref.data());
+        c_gpu_buf.FromDevice(c_host_gpu_ref.data());
         pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
+        std::cout << "The GPU verification result is: " << (pass_gpu ? "correct" : "fail")
                   << std::flush;
     }
...
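For reference, a minimal standalone sketch of the packed-stride rule that f_get_default_stride implements above. The RowMajor/ColumnMajor tags and the default_stride helper below are illustrative stand-ins for the ck_tile layout tags, not part of the commit:

    #include <cstddef>
    #include <iostream>
    #include <type_traits>

    // Stand-in layout tags; in the example above these are
    // ck_tile::tensor_layout::gemm::RowMajor / ColumnMajor.
    struct RowMajor {};
    struct ColumnMajor {};

    // Same rule as f_get_default_stride: a stride of 0 means "packed", i.e. the
    // leading dimension is the row length for row-major storage and the column
    // length for column-major storage; a non-zero stride is taken as-is.
    template <typename Layout>
    std::size_t default_stride(std::size_t row, std::size_t col, std::size_t stride)
    {
        if(stride != 0)
            return stride; // user-provided leading dimension wins
        if constexpr(std::is_same_v<Layout, RowMajor>)
            return col;
        else
            return row;
    }

    int main()
    {
        const std::size_t M = 3840, N = 4096, K = 4096;
        std::cout << "A(MxK, row-major)    lda = " << default_stride<RowMajor>(M, K, 0) << '\n';    // K
        std::cout << "B(KxN, column-major) ldb = " << default_stride<ColumnMajor>(K, N, 0) << '\n'; // K
        std::cout << "C(MxN, row-major)    ldc = " << default_stride<RowMajor>(M, N, 0) << '\n';    // N
    }

With the example's default layouts (A row-major, B column-major, C row-major) a zero stride therefore resolves to lda = K, ldb = K and ldc = N.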
example/ck_tile/03_gemm/gemm_basic.hpp (view file @ 8bd49370)

@@ -4,12 +4,10 @@
 #pragma once

+#include <string>
+
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/epilogue.hpp"
-#include "ck_tile/ops/gemm.hpp"
-#include "ck_tile/host.hpp"
-#include <string>

 template <typename DataType>
 struct GemmBasicTypeConfig;
...
@@ -58,7 +56,6 @@ struct gemm_basic_args
     const void* p_a;
     const void* p_b;
     void* p_c;
-    float epsilon;
     ck_tile::index_t kbatch;
     ck_tile::index_t M;
     ck_tile::index_t N;
...
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp (view file @ 8bd49370)

@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include "gemm_basic.hpp"
 #include <hip/hip_runtime.h>
 #include <cstring>
...
@@ -11,20 +10,24 @@
 #include <string>
 #include <tuple>
+
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_basic.hpp"

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("b", "1", "batch size")
-        .insert("m", "1024", "m dimension")
-        .insert("n", "2048", "n dimension")
-        .insert("k", "64", "k dimension")
+        .insert("m", "3840", "m dimension")
+        .insert("n", "4096", "n dimension")
+        .insert("k", "4096", "k dimension")
         .insert("stride_a", "0", "Tensor A stride")
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("e", "1e-5", "Absolute error tolerance")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
-        .insert("warmup", "10", "number of iterations before benchmark the kernel")
+        .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
...
@@ -32,7 +35,7 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }

-template <typename LayoutA, typename LayoutB, typename LayoutC>
+template <typename ALayout, typename BLayout, typename CLayout>
 float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 {
     // ToDo: This will be modified by the codegen code later.
...
@@ -62,139 +65,180 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                          ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                          ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
     using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;

-    using PipelineProblem =
-        ck_tile::BlockGemmUniversalPipelineProblem<ADataType,
-                                                   BDataType,
-                                                   CDataType,
-                                                   GemmShape,
-                                                   kPadA,
-                                                   kPadB,
-                                                   kPadC,
-                                                   ck_tile::BlockGemmPipelineScheduler::Intrawave>;
-    // The GemmPipeline should also come from the Codegen.
-    using GemmPipeline = ck_tile::BlockGemmPipelineAgBgCrMem<PipelineProblem>;
-    using GemmEpilogue = ck_tile::Default2DEpilogue<
-        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
-    using Kernel =
-        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
-
-    auto kargs = Kernel::MakeKargs(args.p_a,
-                                   args.p_b,
-                                   args.p_c,
-                                   args.M,
-                                   args.N,
-                                   args.K,
-                                   args.stride_A,
-                                   args.stride_B,
-                                   args.stride_C);
-
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
-    constexpr dim3 blocks = Kernel::BlockSize();
-
-    float ave_time = ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-    return ave_time;
-}
-
-template <typename DataType, typename LayoutA, typename LayoutB, typename LayoutC>
-float invoke_gemm(ck_tile::DeviceMem& a_buf,
-                  ck_tile::DeviceMem& b_buf,
-                  ck_tile::DeviceMem& c_buf,
-                  const ck_tile::ArgParser& arg_parser)
-{
-    std::string data_type = arg_parser.get_str("prec");
-
-    if(data_type != DataTypeTraits<DataType>::name)
-    {
-        std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
-                  << data_type << std::endl;
-        return -1; // Or handle the error appropriately
-    }
-
-    float epsilon = arg_parser.get_float("e");
-    ck_tile::index_t batch_size = arg_parser.get_int("b");
-    ck_tile::index_t M = arg_parser.get_int("m");
-    ck_tile::index_t N = arg_parser.get_int("n");
-    ck_tile::index_t K = arg_parser.get_int("k");
-    ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-    ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-    ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-
-    gemm_basic_args args;
-    args.p_a = a_buf.GetDeviceBuffer();
-    args.p_b = b_buf.GetDeviceBuffer();
-    args.p_c = c_buf.GetDeviceBuffer();
-    args.epsilon = epsilon;
-    args.kbatch = batch_size;
-    args.M = M;
-    args.N = N;
-    args.K = K;
-
-    // Only set stride_M and stride_N if they are non-zero and not equal to K.
-    if(stride_a != 0)
-    {
-        args.stride_A = stride_a;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_A = K;
-        }
-        else
-        {
-            args.stride_A = M;
-        }
-    }
-    if(stride_b != 0)
-    {
-        args.stride_B = stride_b;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_B = N;
-        }
-        else
-        {
-            args.stride_B = K;
-        }
-    }
-    if(stride_c != 0)
-    {
-        args.stride_C = stride_c;
-    }
-    else
-    {
-        if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::RowMajor>)
-        {
-            args.stride_C = N;
-        }
-        else
-        {
-            args.stride_C = M;
-        }
-    }
-
-    float ave_time =
-        gemm_calc<LayoutA, LayoutB, LayoutC>(args, ck_tile::stream_config{nullptr, true});
-
-    std::size_t num_byte =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
-    float gb_per_sec = num_byte / 1.E6 / ave_time;
-
-    std::cout << "The overall perfomance of the GEMM with "
-              << "[" << data_type << "]"
-              << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
-              << "is: \n";
-    std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
-              << std::flush;
-
-    return ave_time;
-}
+    using GemmEpilogue = ck_tile::Default2DEpilogue<
+        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
+
+    using BaseGemmPipeline =
+        ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
+                                                                             BDataType,
+                                                                             CDataType,
+                                                                             GemmShape,
+                                                                             ALayout,
+                                                                             BLayout,
+                                                                             CLayout>>;
+
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+
+    float ave_time{0};
+
+    const auto Run = [&](const auto& kernel) {
+        using GemmKernel = ck_tile::remove_cvref_t<decltype(kernel)>;
+        auto kargs = GemmKernel::MakeKargs(args.p_a,
+                                           args.p_b,
+                                           args.p_c,
+                                           args.M,
+                                           args.N,
+                                           args.K,
+                                           args.stride_A,
+                                           args.stride_B,
+                                           args.stride_C);
+
+        const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.kbatch);
+        constexpr dim3 blocks = GemmKernel::BlockSize();
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Lunching kernel with args:"
+                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel, grids, blocks, 0, kargs));
+    };
+
+#define RUN_KERNEL_(has_hot_loop_, tail_number_)                                             \
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<                                     \
+        ck_tile::UniversalGemmPipelineProblem<ADataType,                                     \
+                                              BDataType,                                     \
+                                              CDataType,                                     \
+                                              GemmShape,                                     \
+                                              ALayout,                                       \
+                                              BLayout,                                       \
+                                              CLayout,                                       \
+                                              kPadA,                                         \
+                                              kPadB,                                         \
+                                              kPadC,                                         \
+                                              ck_tile::GemmPipelineScheduler::Intrawave,     \
+                                              has_hot_loop_,                                 \
+                                              tail_number_>>;                                \
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;         \
+    Run(Kernel{});
+
+    if(has_hot_loop)
+    {
+        // Tail pipeline One to Seven
+        if(tail_num == ck_tile::TailNumber::One)
+        {
+            RUN_KERNEL_(true, ck_tile::TailNumber::One);
+        }
+        else if(tail_num == ck_tile::TailNumber::Full)
+        {
+            RUN_KERNEL_(true, ck_tile::TailNumber::Full);
+        }
+
+        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
+        {
+            if(tail_num == ck_tile::TailNumber::Two)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Two);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
+        {
+            if(tail_num == ck_tile::TailNumber::Three)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Three);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
+        {
+            if(tail_num == ck_tile::TailNumber::Four)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Four);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
+        {
+            if(tail_num == ck_tile::TailNumber::Five)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Five);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
+        {
+            if(tail_num == ck_tile::TailNumber::Six)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Six);
+            }
+        }
+        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
+        {
+            if(tail_num == ck_tile::TailNumber::Seven)
+            {
+                RUN_KERNEL_(true, ck_tile::TailNumber::Seven);
+            }
+        }
+    }
+    else
+    {
+        // Tail number always 1
+        if(tail_num == ck_tile::TailNumber::One)
+        {
+            RUN_KERNEL_(false, ck_tile::TailNumber::One);
+        }
+    }
+#undef RUN_KERNEL_
+
+    return ave_time;
+}
+
+template <typename ALayout, typename BLayout, typename CLayout>
+float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
+                  ck_tile::DeviceMem& b_k_n_dev_buf,
+                  ck_tile::DeviceMem& c_m_n_dev_buf,
+                  ck_tile::index_t M,
+                  ck_tile::index_t N,
+                  ck_tile::index_t K,
+                  ck_tile::index_t stride_A,
+                  ck_tile::index_t stride_B,
+                  ck_tile::index_t stride_C,
+                  ck_tile::index_t kbatch,
+                  int n_warmup,
+                  int n_repeat)
+{
+    gemm_basic_args args;
+    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
+    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
+    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
+    args.kbatch = kbatch;
+    args.M = M;
+    args.N = N;
+    args.K = K;
+    args.stride_A = stride_A;
+    args.stride_B = stride_B;
+    args.stride_C = stride_C;
+
+    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
+        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
+
+    std::string op_name{"Gemm{MemBoundPipeline}"};
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_byte =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+    std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
+              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
+              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << std::endl;
+
+    return ave_time;
+}
...
@@ -209,118 +253,120 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");

-    // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
-    using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
-    using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
-
-    // host verify
-    std::vector<int> a_dimensions =
-        (std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, K}
-            : std::vector<int>{K, M};
-    std::vector<int> b_dimensions =
-        (std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            ? std::vector<int>{N, K}
-            : std::vector<int>{K, N};
-    std::vector<int> c_dimensions =
-        (std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, N}
-            : std::vector<int>{N, M};
-
-    ck_tile::HostTensor<ADataType> a_host(a_dimensions);
-    ck_tile::HostTensor<BDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
-
-    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
-    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
-
-    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
-
-    a_buf.ToDevice(a_host.data());
-    b_buf.ToDevice(b_host.data());
-
-    invoke_gemm<ck_tile::half_t, matrix_a_layout, matrix_b_layout, matrix_c_layout>(
-        a_buf, b_buf, c_buf, arg_parser);
-
-    c_buf.FromDevice(c_host_dev.data());
-
-    bool pass = true;
-
-    if(arg_parser.get_int("v") == 1)
-    {
-        // ToDo: Will Add the Element Op (bias) verification in the future.
-        ck_tile::reference_gemm<ADataType,
-                                BDataType,
-                                AccDataType,
-                                CDataType,
-                                matrix_a_layout,
-                                matrix_b_layout,
-                                matrix_c_layout>(a_host, b_host, c_host_ref);
-        pass = ck_tile::check_err(c_host_dev, c_host_ref);
-        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::flush;
-    }
-    else if(arg_parser.get_int("v") == 2)
-    {
-        ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-        ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-        ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-        if(stride_a == 0)
-        {
-            if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_a = K;
-            }
-            else
-            {
-                stride_a = M;
-            }
-        }
-        if(stride_b == 0)
-        {
-            if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_b = N;
-            }
-            else
-            {
-                stride_b = K;
-            }
-        }
-        if(stride_c == 0)
-        {
-            if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_c = N;
-            }
-            else
-            {
-                stride_c = M;
-            }
-        }
-
-        ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
-        ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
-
-        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
-            a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
-
-        c_buf.FromDevice(c_host_gpu_ref.data());
-        pass = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::flush;
-    }
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    ck_tile::index_t batch_size = arg_parser.get_int("b");
+    int n_warmup = arg_parser.get_int("warmup");
+    int n_repeat = arg_parser.get_int("repeat");
+
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    using namespace ck_tile::literals;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout),
+                                            ck_tile::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
+    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
+    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
+
+    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
+        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+
+    // TODO: add different init types
+    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
+    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+
+    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+
+    a_m_k_dev_buf.ToDevice(a_m_k.data());
+    b_k_n_dev_buf.ToDevice(b_k_n.data());
+    c_m_n_dev_buf.SetZero();
+    c_m_n_dev_result.SetZero();
+
+    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
+                                           b_k_n_dev_buf,
+                                           c_m_n_dev_buf,
+                                           M,
+                                           N,
+                                           K,
+                                           stride_A,
+                                           stride_B,
+                                           stride_C,
+                                           batch_size,
+                                           n_warmup,
+                                           n_repeat);
+
+    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    bool pass = true;
+
+    if(arg_parser.get_int("v") == 1)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        c_m_n_host_ref.SetZero();
+
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k, b_k_n, c_m_n_host_ref);
+        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
+        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+    }
+    else if(arg_parser.get_int("v") == 2)
+    {
+        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
+        c_m_n_gpu_ref.SetZero();
+        c_m_n_gpu_buf_ref.SetZero();
+
+        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
+            a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
+
+        std::cout << std::endl << std::flush;
+
+        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
+        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
+        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
+    }

     return pass;
 }
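As a rough sketch of the dispatch pattern used in gemm_calc above: the runtime loop count is translated into compile-time has_hot_loop / tail-number parameters before the kernel type is instantiated, so each combination gets its own fully specialized pipeline. The TailNumber enum subset, kPrefetchStages value and run_pipeline stand-in below are assumptions for illustration only, not the ck_tile types:

    #include <cstdio>

    // Simplified stand-ins; the real code maps (has_hot_loop, tail_num) onto
    // ck_tile::UniversalGemmPipelineProblem template parameters via RUN_KERNEL_.
    enum class TailNumber { One, Two, Full };

    constexpr int kPrefetchStages = 3; // assumed prefetch depth for this sketch

    constexpr bool block_has_hotloop(int num_loop) { return num_loop > kPrefetchStages; }

    constexpr TailNumber block_loop_tail_num(int num_loop)
    {
        switch(num_loop % kPrefetchStages)
        {
        case 1: return TailNumber::One;
        case 2: return TailNumber::Two;
        default: return TailNumber::Full;
        }
    }

    // A "kernel" whose code path is fixed at compile time by the two parameters.
    template <bool HasHotLoop, TailNumber Tail>
    void run_pipeline(int num_loop)
    {
        std::printf("num_loop=%d hot_loop=%d tail=%d\n",
                    num_loop, HasHotLoop, static_cast<int>(Tail));
    }

    // Runtime-to-compile-time dispatch, mirroring the if/else ladder in gemm_calc.
    void dispatch(int num_loop)
    {
        const bool hot = block_has_hotloop(num_loop);
        const TailNumber tail = block_loop_tail_num(num_loop);
        if(hot)
        {
            if(tail == TailNumber::One) run_pipeline<true, TailNumber::One>(num_loop);
            else if(tail == TailNumber::Two) run_pipeline<true, TailNumber::Two>(num_loop);
            else run_pipeline<true, TailNumber::Full>(num_loop);
        }
        else
        {
            // without a hot loop the mem pipeline always takes the TailNumber::One path
            run_pipeline<false, TailNumber::One>(num_loop);
        }
    }

    int main()
    {
        for(int num_loop : {2, 4, 7, 9}) dispatch(num_loop);
    }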
include/ck_tile/core/utility/literals.hpp (new file, 0 → 100644; view file @ 8bd49370)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

namespace ck_tile {
namespace literals {

// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

inline constexpr std::size_t operator""_zu(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

} // namespace literals
} // namespace ck_tile
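A short usage sketch for the new _uz literal, assuming this header is on the include path; the stride value below is made up for illustration:

    #include <cstddef>
    #include <vector>

    #include "ck_tile/core/utility/literals.hpp"

    int main()
    {
        using namespace ck_tile::literals;

        // 1_uz is a std::size_t, so a brace-initialised stride list such as
        // {stride, 1_uz} does not mix signed and unsigned integer types.
        std::size_t stride = 4096;
        std::vector<std::size_t> row_major_strides{stride, 1_uz};
        std::vector<std::size_t> col_major_strides{1_uz, stride};
        return row_major_strides.size() == col_major_strides.size() ? 0 : 1;
    }

This is the literal used by the f_host_tensor_descriptor lambdas in the examples above when building HostTensorDescriptor stride lists.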
include/ck_tile/host/reference/reference_gemm.hpp (view file @ 8bd49370)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include <cstdlib>
+#include <thread>
+
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
-#include "ck_tile/ops/common/tensor_layout.hpp"
-#include <thread>

 namespace ck_tile {
...
@@ -14,48 +15,36 @@ template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType,
-          typename LayoutA,
-          typename LayoutB,
-          typename LayoutC,
           typename AElementOp   = ck_tile::identity,
           typename BElementOp   = ck_tile::identity,
           typename ACCElementOp = ck_tile::identity>
 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
-                                 const HostTensor<BDataType>& b_n_k,
+                                 const HostTensor<BDataType>& b_k_n,
                                  HostTensor<CDataType>& c_m_n,
                                  const AElementOp& a_element_op     = {},
                                  const BElementOp& b_element_op     = {},
                                  const ACCElementOp& acc_element_op = {})
 {
-    const int N = b_n_k.mDesc.get_lengths()[0];
-    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[1]
-                      : a_m_k.mDesc.get_lengths()[0];
-    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[0]
-                      : a_m_k.mDesc.get_lengths()[1];
-
-    auto f = [&](auto m) {
-        for(int n = 0; n < N; ++n)
-        {
-            AccDataType v_acc = 0;
-            for(int k = 0; k < K; ++k)
-            {
-                ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                                    ? a_element_op(a_m_k(m, k))
-                                    : a_element_op(a_m_k(k, m));
-                BDataType v_b = b_element_op(b_n_k(n, k));
-                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
-                         ck_tile::type_convert<AccDataType>(v_b);
-            }
-            c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
-        }
-    };
-
-    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
+    const std::size_t M = a_m_k.get_length(0);
+    const std::size_t N = b_k_n.get_length(1);
+    const std::size_t K = a_m_k.get_length(1);
+
+    auto f_mn = [&](auto m, auto n) {
+        AccDataType v_acc = 0;
+        for(std::size_t k = 0; k < K; ++k)
+        {
+            ADataType v_a = a_element_op(a_m_k(m, k));
+            BDataType v_b = b_element_op(b_k_n(k, n));
+            v_acc += ck_tile::type_convert<AccDataType>(v_a) *
+                     ck_tile::type_convert<AccDataType>(v_b);
+        }
+        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
+    };
+
+    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
 }

 template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
...
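The layout template parameters can be dropped here because HostTensor indexing already goes through per-dimension strides, so the same (m, k) x (k, n) loop works for row-major or column-major storage. A self-contained sketch of that idea, with a hypothetical View2D stand-in instead of ck_tile::HostTensor:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // A strided view like HostTensor: operator()(i, j) uses per-dimension
    // element strides, so storage order is hidden behind the indexing.
    struct View2D
    {
        const float* data;
        std::size_t s0, s1; // element strides for dim0 / dim1
        float operator()(std::size_t i, std::size_t j) const { return data[i * s0 + j * s1]; }
    };

    void reference_gemm(View2D a_m_k, View2D b_k_n, float* c_m_n,
                        std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a_m_k(m, k) * b_k_n(k, n);
                c_m_n[m * N + n] = acc; // C kept row-major and packed in this sketch
            }
    }

    int main()
    {
        const std::size_t M = 2, N = 3, K = 4;
        std::vector<float> a(M * K, 1.f); // row-major A: strides {K, 1}
        std::vector<float> b(K * N, 1.f); // column-major B: strides {1, K}
        std::vector<float> c(M * N, 0.f);
        reference_gemm({a.data(), K, 1}, {b.data(), 1, K}, c.data(), M, N, K);
        assert(c[0] == static_cast<float>(K));
        return 0;
    }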
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (view file @ 8bd49370)

@@ -8,34 +8,29 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

 namespace ck_tile {

-template <typename TilePartitioner_,
-          typename GemmPipeline_,
-          typename EpiloguePipeline_,
-          typename LayoutA_,
-          typename LayoutB_,
-          typename LayoutC_>
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
-    using LayoutA          = remove_cvref_t<LayoutA_>;
-    using LayoutB          = remove_cvref_t<LayoutB_>;
-    using LayoutC          = remove_cvref_t<LayoutC_>;
+    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
+    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
+    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
     static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
-    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
-    using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>;

-    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
+        return TilePartitioner::GridSize(M, N, KBatch);
     }
     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
...
@@ -45,30 +40,30 @@ struct GemmKernel
         const void* a_ptr;
         const void* b_ptr;
         void* c_ptr;
-        ck_tile::index_t M;
-        ck_tile::index_t N;
-        ck_tile::index_t K;
-        ck_tile::index_t stride_A;
-        ck_tile::index_t stride_B;
-        ck_tile::index_t stride_C;
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t stride_A;
+        index_t stride_B;
+        index_t stride_C;
     };

     CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                             const void* b_ptr,
                                                             void* c_ptr,
-                                                            ck_tile::index_t M,
-                                                            ck_tile::index_t N,
-                                                            ck_tile::index_t K,
-                                                            ck_tile::index_t stride_A,
-                                                            ck_tile::index_t stride_B,
-                                                            ck_tile::index_t stride_C)
+                                                            index_t M,
+                                                            index_t N,
+                                                            index_t K,
+                                                            index_t stride_A,
+                                                            index_t stride_B,
+                                                            index_t stride_C)
     {
         return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
     }

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

     CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
...
@@ -79,12 +74,12 @@ struct GemmKernel
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
+                    make_tuple(kargs.stride_A, 1),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
...
@@ -93,14 +88,14 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
+                    make_tuple(1, kargs.stride_A),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
         }();

         auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
...
@@ -110,7 +105,7 @@ struct GemmKernel
                     number<1>{});
             }
             else
+            // Default NK layout
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
                     make_tuple(kargs.N, kargs.K),
...
@@ -123,8 +118,8 @@ struct GemmKernel
         auto a_pad_view = pad_tensor_view(
             a_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadA ? 1 : 0>{});
+            sequence<false, GemmPipeline::kPadA ? true : false>{});

         auto ABlockWindow = make_tile_window(
             a_pad_view,
...
@@ -134,8 +129,8 @@ struct GemmKernel
         auto b_pad_view = pad_tensor_view(
             b_tensor_view,
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
+            sequence<false, GemmPipeline::kPadB ? true : false>{});

         auto BBlockWindow = make_tile_window(
             b_pad_view,
...
@@ -225,15 +220,15 @@ struct GemmKernel
             }
         }

-        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
         auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
+                    make_tuple(kargs.stride_C, 1),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
...
@@ -242,7 +237,7 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
+                    make_tuple(1, kargs.stride_C),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
...
@@ -251,13 +246,13 @@ struct GemmKernel
         auto c_pad_view = pad_tensor_view(
             c_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-            sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
+            sequence<false, GemmPipeline::kPadC ? true : false>{});
-        auto CBlockWindow_pad = make_tile_window(
+        auto CBlockWindow = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});

-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        EpiloguePipeline{}(CBlockWindow, c_block_tile);
     }
 };
...
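A minimal sketch of the overall refactor direction named in the commit message: layout tags travel with the pipeline problem, and the kernel recovers them from the pipeline type instead of carrying its own LayoutA_/LayoutB_/LayoutC_ parameters. The PipelineProblem/Pipeline/Kernel types below are simplified stand-ins, not the ck_tile classes:

    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    // Layout info lives on the problem ...
    template <typename ALayout_, typename BLayout_, typename CLayout_>
    struct PipelineProblem
    {
        using ALayout = ALayout_;
        using BLayout = BLayout_;
        using CLayout = CLayout_;
    };

    // ... the pipeline re-exports it ...
    template <typename Problem>
    struct Pipeline
    {
        using ALayout = typename Problem::ALayout;
        using BLayout = typename Problem::BLayout;
        using CLayout = typename Problem::CLayout;
    };

    // ... and the kernel reads it from the pipeline, so it no longer needs
    // layout template parameters of its own.
    template <typename Pipeline_>
    struct Kernel
    {
        using ALayout = typename Pipeline_::ALayout;
        static constexpr bool a_is_row_major = std::is_same_v<ALayout, RowMajor>;
    };

    using Problem    = PipelineProblem<RowMajor, ColumnMajor, RowMajor>;
    using GemmKernel = Kernel<Pipeline<Problem>>;
    static_assert(GemmKernel::a_is_row_major, "A layout flows from the problem to the kernel");

    int main() { return 0; }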
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem.hpp
deleted
100644 → 0
View file @
d3689b06
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
BlockGemmPipelineAgBgCrDefaultPolicy
>
struct
BlockGemmPipelineAgBgCrMem
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Policy
::
template
GetBlockGemm
<
Problem
>())
>
;
using
CBlockTile
=
typename
BlockGemm
::
CBlockTile
;
using
I0
=
number
<
0
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
AlignmentA
=
Problem
::
AlignmentA
;
static
constexpr
index_t
AlignmentB
=
Problem
::
AlignmentB
;
static
constexpr
index_t
AlignmentC
=
Problem
::
AlignmentC
;
static
constexpr
bool
kPadA
=
Problem
::
kPadA
;
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
get_warp_size
()
/
BlockSize
)
>=
1
?
4
*
get_warp_size
()
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
integer_divide_ceil
(
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
LocalPrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
CK_TILE_HOST_DEVICE
static
constexpr
TailNumber
GetBlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
CK_TILE_HOST_DEVICE
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
BlockGemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
{
};
template
<
>
struct
PipelineImpl
<
BlockGemmPipelineScheduler
::
Intrawave
>
{
template
<
typename
BlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
BlockTile
&
block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
{
load_tile_raw
(
block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
});
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
CK_TILE_DEVICE
void
LocalPrefill
(
DstTileWindow
&
lds_tile_window
,
const
SrcBlockTile
&
src_block_tile
,
const
ElementFunction
&
element_func
)
const
{
const
auto
block_tile_tmp
=
tile_elementwise_in
(
element_func
,
src_block_tile
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem,
                                       CBlockTile& c_block_tile) const
        {
            static_assert(
                std::is_same_v<ADataType,
                               remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A tile in LDS
            ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

            constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
            auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

            // TODO: LDS alignment should come from Policy!
            constexpr index_t a_lds_block_space_size_aligned =
                integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
                                    16) *
                16;

            // B tile in LDS
            BDataType* p_b_lds = static_cast<BDataType*>(
                static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

            constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
            auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 a_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeADramTileDistribution<Problem>());

            // A LDS tile window for store
            auto a_copy_lds_window =
                make_tile_window(a_lds_block,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 a_copy_dram_window.get_tile_distribution());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 b_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeBDramTileDistribution<Problem>());

            // B LDS tile window for store
            auto b_copy_lds_window =
                make_tile_window(b_lds_block,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 b_copy_dram_window.get_tile_distribution());

            // A LDS tile for block GEMM
            auto a_lds_gemm_window = make_tile_window(
                a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // B LDS tile for block GEMM
            auto b_lds_gemm_window = make_tile_window(
                b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // Block GEMM
            constexpr auto block_gemm = BlockGemm();

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // prefetch
            // global read 0
            GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [2, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });
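            // Steady-state schedule (sketch of what follows, not additional code): each hot-loop
            // step (a) waits on LDS, (b) runs block_gemm on the buffers prefilled in the previous
            // step, (c) prefills LDS from the next register-resident tile, and (d) issues the
            // global load whose data is consumed PrefetchStages iterations later, so DRAM latency
            // is hidden behind the MFMA work of earlier stages.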
            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        // block_gemm.LocalPrefetch();
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        block_sync_lds();

                        LocalPrefill(
                            a_copy_lds_window,
                            a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            a_element_func);
                        LocalPrefill(
                            b_copy_lds_window,
                            b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            b_element_func);

                        GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                       a_copy_dram_window);
                        GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                       b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    // block_gemm.LocalPrefetch();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    block_sync_lds();

                    LocalPrefill(a_copy_lds_window,
                                 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                 a_element_func);
                    LocalPrefill(b_copy_lds_window,
                                 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                 b_element_func);
                });

                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            // TODO: TailNumber2Number
            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two)
            {
                HotLoopTail(number<2>{});
            }
            else if constexpr(TailNum == TailNumber::Three)
            {
                HotLoopTail(number<3>{});
            }
            else if constexpr(TailNum == TailNumber::Four)
            {
                HotLoopTail(number<4>{});
            }
            else if constexpr(TailNum == TailNumber::Five)
            {
                HotLoopTail(number<5>{});
            }
            else if constexpr(TailNum == TailNumber::Six)
            {
                HotLoopTail(number<6>{});
            }
            else if constexpr(TailNum == TailNumber::Seven)
            {
                HotLoopTail(number<7>{});
            }
            else if constexpr(TailNum == TailNumber::Full)
            {
                HotLoopTail(number<PrefetchStages>{});
            }

            return c_block_tile;
        }
    };
    template <bool HasHotLoop,
              TailNumber TailNum,
              typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem,
                                   CBlockTile& c_block_tile) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem,
            c_block_tile);
    }

    template <bool HasHotLoop,
              TailNumber TailNum,
              typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem,
                                   CBlockTile& c_block_tile) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem,
            c_block_tile);
    }
};

} // namespace ck_tile
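For orientation, the sketch below shows roughly how a kernel body could drive this pipeline once it has built its A/B DRAM block windows. Only GetSmemSize() and the operator() overload without element functions come from the file above; GemmPipeline, the window objects, K, KPerBlock, HasHotLoop and TailNum stand in for values the calling kernel would supply, and the CBlockTile member type is an assumption rather than something shown in this commit.

    // Sketch only -- placeholders marked; the pipeline zero-initializes c_block_tile internally.
    __shared__ char p_smem[GemmPipeline::GetSmemSize()];               // LDS sized by the policy

    const ck_tile::index_t num_loop =
        ck_tile::integer_divide_ceil(K, KPerBlock);                    // one pipeline step per K tile

    typename GemmPipeline::CBlockTile c_block_tile;                    // assumed accumulator tile type
    GemmPipeline{}.template operator()<HasHotLoop, TailNum>(
        a_dram_block_window, b_dram_block_window, num_loop, p_smem, c_block_tile);
    // ... an epilogue would then store c_block_tile back to the C tensor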
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp    View file @ 8bd49370
...
@@ -13,6 +13,9 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
           bool kPadA_ = false,
           bool kPadB_ = false,
           bool kPadC_ = false>
...
@@ -23,6 +26,10 @@ struct BlockGemmPipelineProblem
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
 
+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
+
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
 
     static constexpr bool kPadA = kPadA_;
     static constexpr bool kPadB = kPadB_;
...
@@ -37,18 +44,29 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          bool kPadA_ = false,
-          bool kPadB_ = false,
-          bool kPadC_ = false,
-          BlockGemmPipelineScheduler Scheduler_ = BlockGemmPipelineScheduler::Intrawave>
-struct BlockGemmUniversalPipelineProblem
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
+          bool kPadA_ = false,
+          bool kPadB_ = false,
+          bool kPadC_ = false,
+          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
+          bool HasHotLoop_ = false,
+          TailNumber TailNum_ = TailNumber::Full>
+struct UniversalGemmPipelineProblem
 {
     using ADataType      = remove_cvref_t<ADataType_>;
     using BDataType      = remove_cvref_t<BDataType_>;
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
 
+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
+
     static constexpr auto Scheduler = Scheduler_;
+    static constexpr auto HasHotLoop = HasHotLoop_;
+    static constexpr auto TailNum    = TailNum_;
 
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
 
     static constexpr bool kPadA = kPadA_;
...
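To make the new layout parameters concrete, here is an illustrative instantiation of the layout-aware problem type introduced above. The parameter order follows the diff; ck_tile::half_t and the tensor_layout::gemm::RowMajor / ColumnMajor tags are assumed to be the library's existing types, and GemmShape is a placeholder for whatever tile shape the example configures.

    // Illustrative only: row-major A, column-major B, row-major C.
    using GemmProblem = ck_tile::UniversalGemmPipelineProblem<
        ck_tile::half_t,                           // ADataType
        ck_tile::half_t,                           // BDataType
        float,                                     // CDataType (accumulator)
        GemmShape,                                 // BlockGemmShape (placeholder)
        ck_tile::tensor_layout::gemm::RowMajor,    // ALayout
        ck_tile::tensor_layout::gemm::ColumnMajor, // BLayout
        ck_tile::tensor_layout::gemm::RowMajor>;   // CLayout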
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_default_policy.hpp    View file @ 8bd49370
...
@@ -4,12 +4,12 @@
 #pragma once
 
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
 
 namespace ck_tile {
 
-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Default policy class should not be templated, put template on member functions instead
-using BlockGemmPipelineAgBgCrDefaultPolicy = BlockGemmPipelineAgBgCrMemCustomPolicy;
+using GemmPipelineAgBgCrDefaultPolicy = GemmPipelineAgBgCrMemCustomPolicy;
 
 } // namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp    View file @ 8bd49370
...
@@ -9,14 +9,14 @@
 namespace ck_tile {
 
-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Maximum Global Memory throughput pipeline with >=32KB data in fly
 // GlobalPrefetchStages: >=2
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 0
 // LocalSharedMemoryBuffer: 1
-struct BlockGemmPipelineAgBgCrMemCustomPolicy
+struct GemmPipelineAgBgCrMemCustomPolicy
 {
     // 3d + padding
     template <typename Problem>
...
@@ -47,8 +47,6 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        using namespace ck_tile;
-
         constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...
@@ -69,7 +67,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
         constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
                                         MakeALdsBlockDescriptor<Problem>().get_element_space_size();
...
@@ -77,7 +75,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
         constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
                                         MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
...
@@ -85,7 +83,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp    View file @ 8bd49370
...
@@ -3,9 +3,11 @@
 #pragma once
 
+#include "ck_tile/core.hpp"
+
 namespace ck_tile {
 
-enum struct BlockGemmPipelineScheduler
+enum struct GemmPipelineScheduler
 {
     Intrawave,
     Interwave,
...
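Finally, a hedged sketch of where the renamed scheduler enum is consumed: it is the Scheduler_ parameter of UniversalGemmPipelineProblem shown earlier. The data types, GemmShape and layout tags reuse the placeholders from the previous example, and since the memory pipeline in this commit only specializes PipelineImpl for Intrawave, Interwave appears here purely to illustrate the parameter position.

    // Illustrative only: overriding the default Intrawave scheduler when composing a problem.
    using InterwaveProblem = ck_tile::UniversalGemmPipelineProblem<
        ck_tile::half_t, ck_tile::half_t, float, GemmShape,
        ck_tile::tensor_layout::gemm::RowMajor,
        ck_tile::tensor_layout::gemm::ColumnMajor,
        ck_tile::tensor_layout::gemm::RowMajor,
        /*kPadA_=*/false, /*kPadB_=*/false, /*kPadC_=*/false,
        ck_tile::GemmPipelineScheduler::Interwave>;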