gaoqiong / composable_kernel_ROCM · Commits

Commit 8bd49370, authored Oct 01, 2024 by Adam Osewski

Refactoring & Move Layout info to pipeline problem.

parent d3689b06
Showing 11 changed files with 469 additions and 841 deletions (+469, -841)
example/ck_tile/03_gemm/gemm_basic.cpp  (+97, -142)
example/ck_tile/03_gemm/gemm_basic.hpp  (+2, -5)
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp  (+253, -207)
include/ck_tile/core/utility/literals.hpp  (+22, -0)
include/ck_tile/host/reference/reference_gemm.hpp  (+19, -30)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  (+42, -47)
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem.hpp  (+0, -394)
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp  (+23, -5)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_default_policy.hpp  (+3, -3)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp  (+5, -7)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp  (+3, -1)
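Note on the shape of the change, as a hedged sketch built only from type names visible in this diff (the concrete data types, tile shape and epilogue aliases are placeholders from the example, not a compilable unit): layout tags are no longer template arguments of GemmKernel; they travel with the pipeline problem and the kernel reads them back as GemmPipeline::ALayout / BLayout / CLayout.

    // Before: layouts handed straight to the kernel.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue,
                                       LayoutA, LayoutB, LayoutC>;

    // After: layouts live in the pipeline problem; the kernel takes only three parameters.
    using PipelineProblem = ck_tile::BlockGemmPipelineProblem<
        ADataType, BDataType, AccDataType, GemmShape,
        ALayout, BLayout, CLayout, kPadA, kPadB, kPadC>;
    using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
    using Kernel       = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;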
example/ck_tile/03_gemm/gemm_basic.cpp
@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-#include "gemm_basic.hpp"
 #include <hip/hip_runtime.h>
 #include <cstring>
@@ -11,6 +10,11 @@
 #include <string>
 #include <tuple>
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_basic.hpp"

 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
...
...
@@ -22,7 +26,6 @@ auto create_args(int argc, char* argv[])
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
         .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
-        .insert("e", "1e-5", "Absolute error tolerance")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "10", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
...
...
@@ -51,13 +54,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;

     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
-    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue,
-                                       LayoutA, LayoutB, LayoutC>;
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

     auto kargs = Kernel::MakeKargs(args.p_a,
                                    args.p_b,
                                    args.p_c,
-                                   args.epsilon,
                                    args.M,
                                    args.N,
                                    args.K,
...
...
@@ -96,7 +97,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
         return -1; // Or handle the error appropriately
     }

-    float epsilon = arg_parser.get_float("e");
     ck_tile::index_t batch_size = arg_parser.get_int("b");
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
...
...
@@ -110,66 +110,34 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
     args.p_a = a_buf.GetDeviceBuffer();
     args.p_b = b_buf.GetDeviceBuffer();
     args.p_c = c_buf.GetDeviceBuffer();
-    args.epsilon = epsilon;
     args.kbatch = batch_size;
     args.M = M;
     args.N = N;
     args.K = K;

-    // Only set stride_M and stride_N if they are non-zero and not equal to K.
-    if(stride_a != 0)
-    {
-        args.stride_A = stride_a;
-    }
-    else
-    {
-        args.stride_A = [&]() {
-            if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_b != 0)
-    {
-        args.stride_B = stride_b;
-    }
-    else
-    {
-        args.stride_B = [&]() {
-            if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return N;
-            }
-            else
-            {
-                return K;
-            }
-        }();
-    }
-    if(stride_c != 0)
-    {
-        args.stride_C = stride_c;
-    }
-    else
-    {
-        args.stride_C = [&]() {
-            if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                return M;
-            }
-            else
-            {
-                return N;
-            }
-        }();
-    }
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            return stride;
+        };
+
+    args.stride_A = f_get_default_stride(M, K, stride_a, LayoutA{});
+    args.stride_B = f_get_default_stride(K, N, stride_b, LayoutB{});
+    args.stride_C = f_get_default_stride(M, N, stride_c, LayoutC{});

     float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
         args, ck_tile::stream_config{nullptr, true});
...
...
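A side note on the f_get_default_stride helper introduced above: when the caller passes a stride of 0, it falls back to the packed leading dimension (the column count for row-major, the row count for column-major). A minimal standalone sketch of the same rule, with a plain bool standing in for the ck_tile layout tags (the tag types themselves are not used here):

    #include <cstddef>
    #include <iostream>

    // Same rule as f_get_default_stride above: row-major tensors get a packed stride
    // equal to the column count, column-major ones get the row count; a non-zero
    // user-supplied stride is passed through untouched.
    std::size_t default_stride(std::size_t row, std::size_t col, std::size_t stride, bool row_major)
    {
        if(stride == 0)
            return row_major ? col : row;
        return stride;
    }

    int main()
    {
        // A is M x K row-major, B is K x N column-major, C is M x N row-major.
        std::size_t M = 128, N = 256, K = 64;
        std::cout << default_stride(M, K, 0, true)  << '\n'; // 64  -> stride_A
        std::cout << default_stride(K, N, 0, false) << '\n'; // 64  -> stride_B
        std::cout << default_stride(M, N, 0, true)  << '\n'; // 256 -> stride_C
    }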
@@ -197,30 +165,57 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");

-    // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
-    using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
-    using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
-    // host verify
-    std::vector<int> a_dimensions =
-        (std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, K}
-            : std::vector<int>{K, M};
-    std::vector<int> b_dimensions =
-        (std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            ? std::vector<int>{N, K}
-            : std::vector<int>{K, N};
-    std::vector<int> c_dimensions =
-        (std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            ? std::vector<int>{M, N}
-            : std::vector<int>{N, M};
-    ck_tile::HostTensor<ADataType> a_host(a_dimensions);
-    ck_tile::HostTensor<BDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
-    ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    using namespace ck_tile::literals;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
+    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
+    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
+
+    ck_tile::HostTensor<ADataType> a_host(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_host(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_ref(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+    ck_tile::HostTensor<CDataType> c_host_dev(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

     ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
     ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
...
...
@@ -259,6 +254,9 @@ int main(int argc, char* argv[])
                                       BDataType,
                                       AccDataType,
                                       CodegenGemmShape,
+                                      ALayout,
+                                      BLayout,
+                                      CLayout,
                                       kPadA,
                                       kPadB,
                                       kPadC>;
...
...
@@ -266,9 +264,9 @@ int main(int argc, char* argv[])
     using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
-    invoke_gemm<ck_tile::half_t, matrix_a_layout, matrix_b_layout, matrix_c_layout,
-                CodegenPipelineProblem, CodegenGemmPipeline, CodegenGemmShape>(
-        a_buf, b_buf, c_buf, arg_parser);
+    invoke_gemm<ck_tile::half_t, ALayout, BLayout, CLayout,
+                CodegenPipelineProblem, CodegenGemmPipeline, CodegenGemmShape>(
+        a_buf, b_buf, c_buf, arg_parser);
...
...
@@ -280,17 +278,12 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 1)
     {
         // ToDo: Will Add the Element Op (bias) verification in the future.
-        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType,
-                                matrix_a_layout, matrix_b_layout, matrix_c_layout>(
-            a_host, b_host, c_host_ref);
+        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+            a_host, b_host, c_host_ref);
         pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);

-        std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
+        std::cout << "The CPU verification result is:" << (pass_cpu ? "correct" : "fail")
                   << std::flush;
     }
...
...
@@ -298,57 +291,19 @@ int main(int argc, char* argv[])
     if(arg_parser.get_int("v") == 2)
     {
-        ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
-        ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
-        ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
-        if(stride_a == 0)
-        {
-            if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_a = M;
-            }
-            else
-            {
-                stride_a = K;
-            }
-        }
-        if(stride_b == 0)
-        {
-            if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                stride_b = N;
-            }
-            else
-            {
-                stride_b = K;
-            }
-        }
-        if(stride_c == 0)
-        {
-            if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
-            {
-                stride_c = M;
-            }
-            else
-            {
-                stride_c = N;
-            }
-        }
-        ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
+        ck_tile::HostTensor<CDataType> c_host_gpu_ref(
+            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
         ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
         c_gpu_buf.SetZero();

         ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
-            a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
+            a_buf, b_buf, c_gpu_buf, M, N, K, stride_A, stride_B, stride_C);

-        c_buf.FromDevice(c_host_gpu_ref.data());
+        c_gpu_buf.FromDevice(c_host_gpu_ref.data());
         pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);

-        std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
+        std::cout << "The GPU verification result is: " << (pass_gpu ? "correct" : "fail")
                   << std::flush;
     }
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
...
...
@@ -4,12 +4,10 @@
 #pragma once

-#include <string>
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
+#include <string>

 template <typename DataType>
 struct GemmBasicTypeConfig;
@@ -58,7 +56,6 @@ struct gemm_basic_args
     const void* p_a;
     const void* p_b;
     void* p_c;
-    float epsilon;
     ck_tile::index_t kbatch;
     ck_tile::index_t M;
     ck_tile::index_t N;
...
...
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
(Diff collapsed in this view: +253, -207.)
include/ck_tile/core/utility/literals.hpp
new file 0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

namespace ck_tile {
namespace literals {

// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

inline constexpr std::size_t operator""_zu(unsigned long long size)
{
    return static_cast<std::size_t>(size);
}

} // namespace literals
} // namespace ck_tile
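For context, the _uz suffix mirrors the C++23 size_t literal so that initializer lists such as {stride, 1_uz} in the example stay homogeneous in std::size_t. A self-contained sketch (it reproduces the operator locally rather than including the new header, so it compiles on its own):

    #include <cstddef>
    #include <type_traits>

    namespace demo_literals {
    // Same definition as operator""_uz in the new header above.
    inline constexpr std::size_t operator""_uz(unsigned long long size)
    {
        return static_cast<std::size_t>(size);
    }
    } // namespace demo_literals

    int main()
    {
        using namespace demo_literals;
        auto one = 1_uz; // a std::size_t, not an int
        static_assert(std::is_same_v<decltype(one), std::size_t>);
    }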
include/ck_tile/host/reference/reference_gemm.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <cstdlib>
+#include <thread>

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
-#include "ck_tile/ops/common/tensor_layout.hpp"
-#include <thread>

 namespace ck_tile {
...
...
@@ -14,48 +15,36 @@ template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType,
-          typename LayoutA,
-          typename LayoutB,
-          typename LayoutC,
           typename AElementOp = ck_tile::identity,
           typename BElementOp = ck_tile::identity,
           typename ACCElementOp = ck_tile::identity>
 CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
-                                 const HostTensor<BDataType>& b_n_k,
+                                 const HostTensor<BDataType>& b_k_n,
                                  HostTensor<CDataType>& c_m_n,
                                  const AElementOp& a_element_op = {},
                                  const BElementOp& b_element_op = {},
                                  const ACCElementOp& acc_element_op = {})
 {
-    const int N = b_n_k.mDesc.get_lengths()[0];
-    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[1]
-                      : a_m_k.mDesc.get_lengths()[0];
-    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                      ? a_m_k.mDesc.get_lengths()[0]
-                      : a_m_k.mDesc.get_lengths()[1];
-
-    auto f = [&](auto m) {
-        for(int n = 0; n < N; ++n)
-        {
+    const std::size_t M = a_m_k.get_length(0);
+    const std::size_t N = b_k_n.get_length(1);
+    const std::size_t K = a_m_k.get_length(1);
+
+    auto f_mn = [&](auto m, auto n) {
         AccDataType v_acc = 0;
-        for(int k = 0; k < K; ++k)
+        for(std::size_t k = 0; k < K; ++k)
         {
-            ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
-                                ? a_element_op(a_m_k(m, k))
-                                : a_element_op(a_m_k(k, m));
-            BDataType v_b = b_element_op(b_n_k(n, k));
+            ADataType v_a = a_element_op(a_m_k(m, k));
+            BDataType v_b = b_element_op(b_k_n(k, n));
             v_acc += ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
         }
         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
-        }
     };

-    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
 }

 template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType>
...
...
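The refactored host reference above now assumes A is stored M x K and B is stored K x N and drops the layout template parameters. A plain, single-threaded restatement of the same accumulation it parallelizes (standalone, no ck_tile types):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // C(m, n) = sum over k of A(m, k) * B(k, n), with A laid out M x K and B laid out K x N.
    int main()
    {
        const std::size_t M = 2, N = 3, K = 4;
        std::vector<float> a(M * K, 1.0f), b(K * N, 2.0f), c(M * N, 0.0f);

        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                c[m * N + n] = acc; // every entry is K * 1 * 2 = 8 here
            }

        std::cout << c[0] << '\n'; // prints 8
    }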
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
...
...
@@ -8,34 +8,29 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

 namespace ck_tile {

-template <typename TilePartitioner_,
-          typename GemmPipeline_,
-          typename EpiloguePipeline_,
-          typename LayoutA_,
-          typename LayoutB_,
-          typename LayoutC_>
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
-    using LayoutA          = remove_cvref_t<LayoutA_>;
-    using LayoutB          = remove_cvref_t<LayoutB_>;
-    using LayoutC          = remove_cvref_t<LayoutC_>;
+    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
+    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
+    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
     static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
-    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
-    using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>;

-    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
+        return TilePartitioner::GridSize(M, N, KBatch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
...
...
@@ -45,30 +40,30 @@ struct GemmKernel
         const void* a_ptr;
         const void* b_ptr;
         void* c_ptr;
-        ck_tile::index_t M;
-        ck_tile::index_t N;
-        ck_tile::index_t K;
-        ck_tile::index_t stride_A;
-        ck_tile::index_t stride_B;
-        ck_tile::index_t stride_C;
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t stride_A;
+        index_t stride_B;
+        index_t stride_C;
     };

     CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                             const void* b_ptr,
                                                             void* c_ptr,
-                                                            ck_tile::index_t M,
-                                                            ck_tile::index_t N,
-                                                            ck_tile::index_t K,
-                                                            ck_tile::index_t stride_A,
-                                                            ck_tile::index_t stride_B,
-                                                            ck_tile::index_t stride_C)
+                                                            index_t M,
+                                                            index_t N,
+                                                            index_t K,
+                                                            index_t stride_A,
+                                                            index_t stride_B,
+                                                            index_t stride_C)
     {
         return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
     }

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

     CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
...
...
@@ -79,12 +74,12 @@ struct GemmKernel
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
+                    make_tuple(kargs.stride_A, 1),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
...
...
@@ -93,14 +88,14 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
+                    make_tuple(1, kargs.stride_A),
                     number<GemmPipeline::AlignmentA>{},
                     number<1>{});
             }
         }();

         auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
...
...
@@ -110,7 +105,7 @@ struct GemmKernel
                     number<1>{});
             }
             else
-            { // Default NK layout
+            {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
                     make_tuple(kargs.N, kargs.K),
...
...
@@ -123,8 +118,8 @@ struct GemmKernel
         auto a_pad_view = pad_tensor_view(
             a_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadA ? 1 : 0>{});
+            sequence<false, GemmPipeline::kPadA ? true : false>{});

         auto ABlockWindow = make_tile_window(a_pad_view,
...
...
@@ -134,8 +129,8 @@ struct GemmKernel
         auto b_pad_view = pad_tensor_view(
             b_tensor_view,
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
+            sequence<false, GemmPipeline::kPadB ? true : false>{});

         auto BBlockWindow = make_tile_window(b_pad_view,
...
...
@@ -225,15 +220,15 @@ struct GemmKernel
             }
         }

-        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
         auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
+                    make_tuple(kargs.stride_C, 1),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
...
...
@@ -242,7 +237,7 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
+                    make_tuple(1, kargs.stride_C),
                     number<GemmPipeline::AlignmentC>{},
                     number<1>{});
             }
...
...
@@ -251,13 +246,13 @@ struct GemmKernel
         auto c_pad_view = pad_tensor_view(
             c_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-            sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
-        auto CBlockWindow_pad = make_tile_window(
+            sequence<false, GemmPipeline::kPadC ? true : false>{});
+        auto CBlockWindow = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});

-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        EpiloguePipeline{}(CBlockWindow, c_block_tile);
     }
 };
...
...
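A hedged host-side sketch of how the slimmed-down kernel interface above is meant to be driven. Only members visible in this diff are referenced; the actual launch (stream configuration and launch helper) sits outside this diff and is omitted, and the surrounding variables are assumed to come from the example's argument parsing:

    // Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>
    auto kargs        = Kernel::MakeKargs(p_a, p_b, p_c, M, N, K, stride_A, stride_B, stride_C);
    const auto grids  = Kernel::GridSize(M, N, kbatch); // forwards to TilePartitioner::GridSize
    const auto blocks = Kernel::BlockSize();            // dim3(GemmPipeline::BlockSize)
    const auto lds    = Kernel::GetSmemSize();          // max of pipeline and epilogue LDS needs
    // grids, blocks, lds and kargs would then be handed to the launch helper (not shown here).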
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = BlockGemmPipelineAgBgCrDefaultPolicy>
struct BlockGemmPipelineAgBgCrMem
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using BlockGemm  = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
    using CBlockTile = typename BlockGemm::CBlockTile;
    using I0         = number<0>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t AlignmentA = Problem::AlignmentA;
    static constexpr index_t AlignmentB = Problem::AlignmentB;
    static constexpr index_t AlignmentC = Problem::AlignmentC;

    static constexpr bool kPadA = Problem::kPadA;
    static constexpr bool kPadB = Problem::kPadB;
    static constexpr bool kPadC = Problem::kPadC;

    static constexpr auto Scheduler = Problem::Scheduler;

    static constexpr index_t WgpPerCU =
        (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
        32768 / WgpPerCU,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2;

    static constexpr index_t LocalPrefillStages = 1;
    static constexpr index_t GlobalBufferNum    = PrefetchStages;

    CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1)
        {
            return TailNumber::One;
        }
        else if(num_loop % PrefetchStages == 2)
        {
            return TailNumber::Two;
        }
        else if(num_loop % PrefetchStages == 3)
        {
            return TailNumber::Three;
        }
        else if(num_loop % PrefetchStages == 4)
        {
            return TailNumber::Four;
        }
        else if(num_loop % PrefetchStages == 5)
        {
            return TailNumber::Five;
        }
        else if(num_loop % PrefetchStages == 6)
        {
            return TailNumber::Six;
        }
        else if(num_loop % PrefetchStages == 7)
        {
            return TailNumber::Seven;
        }
        else
        {
            return TailNumber::Full;
        }
    }

    CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetStaticLdsSize()
    {
        return ck_tile::integer_divide_ceil(
                   sizeof(ADataType) *
                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
                   16) *
                   16 +
               sizeof(BDataType) *
                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <BlockGemmPipelineScheduler Scheduler>
    struct PipelineImpl
    {
    };

    template <>
    struct PipelineImpl<BlockGemmPipelineScheduler::Intrawave>
    {
        template <typename BlockTile, typename SrcTileWindow>
        CK_TILE_DEVICE void GlobalPrefetch(BlockTile& block_tile, SrcTileWindow& dram_tile_window) const
        {
            load_tile_raw(block_tile, dram_tile_window);
            move_tile_window(dram_tile_window, {0, KPerBlock});
        }

        template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
        CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                         const SrcBlockTile& src_block_tile,
                                         const ElementFunction& element_func) const
        {
            const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
            store_tile(lds_tile_window, block_tile_tmp);
        }

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem,
                                       CBlockTile& c_block_tile) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A tile in LDS
            ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
            constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
            auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

            // TODO: LDS alignment should come from Policy!
            constexpr index_t a_lds_block_space_size_aligned =
                integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
                16;

            // B tile in LDS
            BDataType* p_b_lds = static_cast<BDataType*>(
                static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
            constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
            auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 a_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeADramTileDistribution<Problem>());

            // A LDS tile window for store
            auto a_copy_lds_window =
                make_tile_window(a_lds_block,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 a_copy_dram_window.get_tile_distribution());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 b_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeBDramTileDistribution<Problem>());

            // B LDS tile window for store
            auto b_copy_lds_window =
                make_tile_window(b_lds_block,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 b_copy_dram_window.get_tile_distribution());

            // A LDS tile for block GEMM
            auto a_lds_gemm_window = make_tile_window(
                a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // B LDS tile for block GEMM
            auto b_lds_gemm_window = make_tile_window(
                b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // Block GEMM
            constexpr auto block_gemm = BlockGemm();

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // prefetch
            // global read 0
            GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [2, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        // block_gemm.LocalPrefetch();
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        block_sync_lds();

                        LocalPrefill(a_copy_lds_window,
                                     a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                     a_element_func);
                        LocalPrefill(b_copy_lds_window,
                                     b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                     b_element_func);

                        GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                        GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    // block_gemm.LocalPrefetch();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    block_sync_lds();

                    LocalPrefill(a_copy_lds_window,
                                 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                 a_element_func);
                    LocalPrefill(b_copy_lds_window,
                                 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                                 b_element_func);
                });

                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            // TODO: TailNumber2Number
            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                // block_gemm.LocalPrefetch();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two)
            {
                HotLoopTail(number<2>{});
            }
            else if constexpr(TailNum == TailNumber::Three)
            {
                HotLoopTail(number<3>{});
            }
            else if constexpr(TailNum == TailNumber::Four)
            {
                HotLoopTail(number<4>{});
            }
            else if constexpr(TailNum == TailNumber::Five)
            {
                HotLoopTail(number<5>{});
            }
            else if constexpr(TailNum == TailNumber::Six)
            {
                HotLoopTail(number<6>{});
            }
            else if constexpr(TailNum == TailNumber::Seven)
            {
                HotLoopTail(number<7>{});
            }
            else if constexpr(TailNum == TailNumber::Full)
            {
                HotLoopTail(number<PrefetchStages>{});
            }

            return c_block_tile;
        }
    };

    template <bool HasHotLoop,
              TailNumber TailNum,
              typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem,
                                   CBlockTile& c_block_tile) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem,
            c_block_tile);
    }

    template <bool HasHotLoop,
              TailNumber TailNum,
              typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   index_t num_loop,
                                   void* p_smem,
                                   CBlockTile& c_block_tile) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem,
            c_block_tile);
    }
};

} // namespace ck_tile
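To make the PrefetchStages heuristic of the deleted pipeline concrete, here is a standalone re-computation for one assumed configuration (128x128x32 tile, 2-byte A/B elements, 256-thread block, 64-wide waves; these numbers are illustrative and do not come from this commit):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        const std::size_t warp_size = 64, block_size = 256;
        const std::size_t m_per_block = 128, n_per_block = 128, k_per_block = 32;
        const std::size_t a_bytes = 2, b_bytes = 2; // sizeof(fp16) for A and B

        // Workgroups per CU, clamped to at least 1.
        const std::size_t wgp_per_cu =
            (4 * warp_size / block_size) >= 1 ? 4 * warp_size / block_size : 1;       // 1
        // Bytes of A and B fetched per K-slice of the tile.
        const std::size_t bytes_per_k_slice =
            (m_per_block * a_bytes + n_per_block * b_bytes) * k_per_block;             // 16384
        // Stages needed to keep >= 32 KB of data in flight, rounded up.
        const std::size_t full_mem_band_stages =
            (32768 / wgp_per_cu + bytes_per_k_slice - 1) / bytes_per_k_slice;          // 2
        // Clamp into the [2, 8] range used by the pipeline.
        const std::size_t prefetch_stages =
            full_mem_band_stages >= 2 ? (full_mem_band_stages <= 8 ? full_mem_band_stages : 8) : 2;

        std::cout << prefetch_stages << '\n'; // prints 2 for this configuration
    }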
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...
...
@@ -13,6 +13,9 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
           bool kPadA_ = false,
           bool kPadB_ = false,
           bool kPadC_ = false>
...
...
@@ -23,6 +26,10 @@ struct BlockGemmPipelineProblem
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
+
     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

     static constexpr bool kPadA = kPadA_;
     static constexpr bool kPadB = kPadB_;
...
...
@@ -37,18 +44,29 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
           bool kPadA_ = false,
           bool kPadB_ = false,
           bool kPadC_ = false,
-          BlockGemmPipelineScheduler Scheduler_ = BlockGemmPipelineScheduler::Intrawave>
-struct BlockGemmUniversalPipelineProblem
+          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
+          bool HasHotLoop_ = false,
+          TailNumber TailNum_ = TailNumber::Full>
+struct UniversalGemmPipelineProblem
 {
     using ADataType      = remove_cvref_t<ADataType_>;
     using BDataType      = remove_cvref_t<BDataType_>;
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

+    using ALayout = remove_cvref_t<ALayout_>;
+    using BLayout = remove_cvref_t<BLayout_>;
+    using CLayout = remove_cvref_t<CLayout_>;
+
     static constexpr auto Scheduler  = Scheduler_;
+    static constexpr auto HasHotLoop = HasHotLoop_;
+    static constexpr auto TailNum    = TailNum_;

     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

     static constexpr bool kPadA = kPadA_;
...
...
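A hedged sketch of instantiating the renamed problem type above. The parameter order follows this diff; the concrete data types, tile-shape alias and padding flags are placeholders from the example rather than verified values:

    using Problem = ck_tile::UniversalGemmPipelineProblem<
        ADataType, BDataType, AccDataType, GemmShape,
        ALayout, BLayout, CLayout,
        kPadA, kPadB, kPadC,
        ck_tile::GemmPipelineScheduler::Intrawave,
        /*HasHotLoop_=*/true, ck_tile::TailNumber::Full>;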
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_default_policy.hpp
...
...
@@ -4,12 +4,12 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"

 namespace ck_tile {

-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Default policy class should not be templated, put template on member functions instead
-using BlockGemmPipelineAgBgCrDefaultPolicy = BlockGemmPipelineAgBgCrMemCustomPolicy;
+using GemmPipelineAgBgCrDefaultPolicy = GemmPipelineAgBgCrMemCustomPolicy;

 } // namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp
...
...
@@ -9,14 +9,14 @@
 namespace ck_tile {

-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Maximum Global Memory throughput pipeline with >=32KB data in fly
 // GlobalPrefetchStages: >=2
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 0
 // LocalSharedMemoryBuffer: 1
-struct BlockGemmPipelineAgBgCrMemCustomPolicy
+struct GemmPipelineAgBgCrMemCustomPolicy
 {
     // 3d + padding
     template <typename Problem>
...
...
@@ -47,8 +47,6 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        using namespace ck_tile;
-
         constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...
...
@@ -69,7 +67,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
         constexpr index_t smem_size_a =
             sizeof(typename Problem::ADataType) *
             MakeALdsBlockDescriptor<Problem>().get_element_space_size();
...
...
@@ -77,7 +75,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
         constexpr index_t smem_size_b =
             sizeof(typename Problem::BDataType) *
             MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
...
...
@@ -85,7 +83,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
         constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp → include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...
...
@@ -3,9 +3,11 @@
 #pragma once

 #include "ck_tile/core.hpp"

 namespace ck_tile {

-enum struct BlockGemmPipelineScheduler
+enum struct GemmPipelineScheduler
 {
     Intrawave,
     Interwave,
...
...