gaoqiong / composable_kernel_ROCM · Commits

Commit 43adf1fa
authored Dec 13, 2023 by Harisankar Sadasivan
parent ab3d3b4a

    clang format

Showing 4 changed files with 1154 additions and 1014 deletions (+1154 −1014)
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc   +0 −0
include/ck/host_utility/kernel_launch.hpp                                            +15 −16
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp   +438 −337
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp        +701 −661
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc @ 43adf1fa
File mode changed from 100644 to 100755 (no content changes: +0 −0).
include/ck/host_utility/kernel_launch.hpp @ 43adf1fa
@@ -11,7 +11,7 @@
#ifndef KERNARG_PRELOAD
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
@@ -19,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
@@ -49,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
@@ -81,7 +81,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
#else
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
@@ -92,7 +92,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
@@ -109,9 +109,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        //
        // warm up
        const int nrepeat = 1000;
        for(auto i = 0; i < nrepeat; i++)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        hip_check_error(hipGetLastError());
#if DEBUG_LOG
@@ -127,9 +127,9 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        // hip_check_error(hipGetLastError());
        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
@@ -140,8 +140,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());
        return 0;
@@ -155,7 +154,7 @@ float launch_and_time_kernel(const StreamConfig &stream_config,
}
#endif

template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
@@ -164,7 +163,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
                                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
@@ -195,7 +194,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config,
        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
        for(int i = 0; i < nrepeat; ++i)
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
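The function being reflowed above implements HIP's standard event-based timing loop: warm up, record a start event on the target stream, launch the kernel nrepeat times, record a stop event, and average the elapsed time. A minimal self-contained sketch of the same pattern, assuming an illustrative saxpy kernel and nrepeat = 100 (neither is part of this commit):

    #include <hip/hip_runtime.h>
    #include <cstdio>

    // Illustrative kernel; stands in for the templated `kernel` parameter.
    __global__ void saxpy(int n, float a, const float* x, float* y)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n) y[i] = a * x[i] + y[i];
    }

    int main()
    {
        const int n = 1 << 20;
        float *x, *y;
        (void)hipMalloc(&x, n * sizeof(float));
        (void)hipMalloc(&y, n * sizeof(float));

        dim3 block(256), grid((n + 255) / 256);

        // Warm up once so timing excludes first-launch overhead.
        hipLaunchKernelGGL(saxpy, grid, block, 0, 0, n, 2.0f, x, y);
        (void)hipDeviceSynchronize();

        hipEvent_t start, stop;
        (void)hipEventCreate(&start);
        (void)hipEventCreate(&stop);

        const int nrepeat = 100; // assumption; the file uses 1000
        (void)hipEventRecord(start, 0);
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(saxpy, grid, block, 0, 0, n, 2.0f, x, y);
        (void)hipEventRecord(stop, 0);
        (void)hipEventSynchronize(stop);

        float total_ms = 0.f;
        (void)hipEventElapsedTime(&total_ms, start, stop);
        printf("avg kernel time: %f ms\n", total_ms / nrepeat);

        (void)hipFree(x);
        (void)hipFree(y);
        return 0;
    }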
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp @ 43adf1fa
(Diff collapsed on the original page; +438 −337.)
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp @ 43adf1fa
@@ -16,10 +16,9 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck
{

template <typename GridwiseTsmm,
          typename FloatAB,
          typename FloatC,
          typename BLayout,
@@ -27,35 +26,68 @@ namespace ck
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop,
          typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_tsmm_dl_v1r3(const FloatAB* p_a_grid,
                    const FloatAB* p_b_grid,
                    FloatC* p_c_grid,
                    index_t M,
                    index_t N,
                    index_t K,
                    index_t K0,
                    index_t k_batch,
                    index_t MPadded,
                    index_t NPadded,
                    const Block2CTileMap block_2_ctile_map) //: in __global__ functions, struct is
                                                            // better for reduced load overhead
{
    // strides depend on B's layout
    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
    {
        GridwiseTsmm::template Run<HasMainKBlockLoop,
                                   HasDoubleTailKBlockLoop,
                                   GridwiseTsmm,
                                   CGlobalMemoryDataOperation>(
            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch, K, N, N, MPadded, NPadded, block_2_ctile_map);
    }
    else
    {
        GridwiseTsmm::template Run<HasMainKBlockLoop,
                                   HasDoubleTailKBlockLoop,
                                   GridwiseTsmm,
                                   CGlobalMemoryDataOperation>(
            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch, K, K, N, MPadded, NPadded, block_2_ctile_map);
    }
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
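The trailing comment on kernel_tsmm_dl_v1r3's last parameter points at a general HIP consideration: a __global__ function's scalar arguments land in the kernarg segment, and packing them into one trivially-copyable struct can reduce argument-load overhead. A hedged sketch of that idiom, with hypothetical names (TsmmParams is not a type in this file):

    #include <hip/hip_runtime.h>

    // All launch-time scalars packed into one trivially-copyable struct,
    // so the kernel reads a single contiguous kernarg region.
    struct TsmmParams
    {
        int M, N, K, K0, k_batch, MPadded, NPadded;
    };

    __global__ void kernel_with_struct_args(const float* a, const float* b, float* c, TsmmParams p)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < p.M * p.N)
            c[i] = 0.f; // placeholder body; a real kernel would use p.K, p.K0, ...
        (void)a;
        (void)b;
    }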
@@ -83,8 +115,8 @@ namespace ck
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseTsmmDl_km_kn_mn
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -96,9 +128,9 @@ namespace ck
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument //
    {
        Argument(const FloatAB* p_a_grid_,
                 const FloatAB* p_b_grid_,
                 FloatC* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
@@ -128,9 +160,9 @@ namespace ck
        }

        // private:
        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        FloatC* p_c_grid;
        index_t M, N, K;
        index_t StrideA, StrideB, StrideC;
@@ -214,19 +246,18 @@ namespace ck
        index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
    {
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
            return transform_tensor_descriptor(
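The lambda above picks the naive descriptor's strides from the layout: row-major A addresses element (m, k) at m * StrideA + k, column-major at m + k * StrideA. A small illustrative helper (not from the library) making that addressing rule explicit:

    #include <cstddef>

    // Linear offset of element (m, k) in an M x K matrix, given the two
    // strides a naive descriptor would carry.
    inline std::size_t offset_m_k(std::size_t m, std::size_t k,
                                  std::size_t stride_m, std::size_t stride_k)
    {
        return m * stride_m + k * stride_k;
    }

    // Row-major A (leading dimension StrideA >= K): strides (StrideA, 1).
    // Column-major A (leading dimension StrideA >= M): strides (1, StrideA).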
@@ -255,19 +286,18 @@ namespace ck
        index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
    {
        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
            return transform_tensor_descriptor(
@@ -290,19 +320,18 @@ namespace ck
    __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
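The PadM/PadN formula rounds each dimension up to the next multiple of the block tile, and the outer % makes the pad zero when the size is already aligned: with MPerBlock = 64 and M = 100, PadM = (64 - 100 % 64) % 64 = 28, so the padded M is 128. A tiny sketch checking both cases:

    constexpr int pad_to_multiple(int x, int tile)
    {
        return (tile - x % tile) % tile;
    }

    static_assert(pad_to_multiple(100, 64) == 28, "100 + 28 = 128 = 2 * 64");
    static_assert(pad_to_multiple(128, 64) == 0, "already aligned, no padding");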
@@ -335,7 +364,7 @@ namespace ck
    using BGridDesc_Kbatch_K0_N_K1 =
        decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        // const auto MPadded = CalculateMPadded(karg.M);
@@ -361,7 +390,7 @@ namespace ck
    // KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
    __host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
        const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
    {
        const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
        const auto K0     = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
@@ -383,7 +412,7 @@ namespace ck
    }

    __host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
        const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
    {
        const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
        const auto K0     = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
@@ -405,7 +434,7 @@ namespace ck
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -451,8 +480,20 @@ namespace ck
              bool HasDoubleTailKBlockLoop,
              typename GridwiseTsmm,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    __device__ static void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               index_t M,
                               index_t N,
                               index_t K,
                               index_t K0,
                               index_t k_batch,
                               index_t StrideA,
                               index_t StrideB,
                               index_t StrideC,
                               index_t MPadded,
                               index_t NPadded,
                               const Block2CTileMap& block_2_ctile_map)
    {
        constexpr index_t shared_block_size =
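Run's k_batch argument is the split-K factor: the K dimension is partitioned across k_batch groups of workgroups, each producing a partial C tile that is combined per CGlobalMemoryDataOperation (for example an atomic add). A host-side reference sketch of the decomposition, assuming row-major operands (loop structure illustrative, not the library's tiling):

    #include <vector>

    // Reference split-K GEMM: each of the k_batch slices accumulates its
    // partial product into C, mimicking per-slice kernels that combine
    // results with an add operation.
    void gemm_splitk_reference(const std::vector<float>& A, // M x K, row-major
                               const std::vector<float>& B, // K x N, row-major
                               std::vector<float>& C,       // M x N, zero-initialized
                               int M, int N, int K, int k_batch)
    {
        const int k_per_batch = (K + k_batch - 1) / k_batch;
        for(int kb = 0; kb < k_batch; ++kb) // one "slice" per kernel launch
        {
            const int k_begin = kb * k_per_batch;
            const int k_end   = k_begin + k_per_batch < K ? k_begin + k_per_batch : K;
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = k_begin; k < k_end; ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    C[m * N + n] += partial; // combine step (e.g. AtomicAdd on GPU)
                }
        }
    }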
@@ -464,8 +505,7 @@ namespace ck
            M, MPadded, K, StrideA, k_batch, K0); //
        const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
            K, NPadded, N, StrideB, k_batch, K0); //
        const auto c_grid_desc_m_n = GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
        const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
            GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
@@ -482,15 +522,15 @@ namespace ck
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(get_block_1d_id(), N, k_batch);

        // HACK: this force index data into SGPR
        const index_t im0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0       = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
        const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
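The "HACK" above relies on __builtin_amdgcn_readfirstlane, which broadcasts lane 0's value to the whole wavefront; since the result is provably uniform, the compiler can keep the tile indices in scalar registers (SGPRs) rather than one copy per lane in VGPRs. A minimal sketch of the idiom (kernel and buffers are illustrative):

    #include <hip/hip_runtime.h>

    __global__ void tile_kernel(const int* tile_of_block, float* out)
    {
        // blockIdx-derived values are uniform across the wavefront, but the
        // compiler cannot always prove it; readfirstlane makes the uniformity
        // explicit so the index can live in an SGPR.
        const int tile = __builtin_amdgcn_readfirstlane(tile_of_block[blockIdx.x]);
        out[tile * blockDim.x + threadIdx.x] = 0.f;
    }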
@@ -593,7 +633,7 @@ namespace ck
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;

        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_k0_n_k1_thread_desc.GetElementSpaceSize());
@@ -632,7 +672,7 @@ namespace ck
                                           b_thread_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
@@ -691,11 +731,11 @@ namespace ck
                a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                a_block_slice_copy_step);
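This loop is the LDS double-buffering pattern: while the block computes from one half of shared memory, the blockwise copy prefetches the next K0 slice into the other half, and HasDoubleTailKBlockLoop selects a tail that drains the last one or two slices. A schematic sketch of the ping-pong structure, with illustrative callables standing in for the copy and GEMM stages (a real implementation guards the trailing prefetches):

    // Schematic ping-pong over two LDS buffers; Load/Compute stand in for
    // the blockwise copy RunRead/RunWrite and the threadwise GEMM.
    template <typename Load, typename Compute>
    __device__ void double_buffered_loop(int K0, int K0PerBlock, Load load, Compute compute)
    {
        load(/*buf=*/0, /*k=*/0); // prologue: fill the even buffer
        int k = 0;
        do
        {
            load(1, k + K0PerBlock);     // prefetch odd buffer
            __syncthreads();
            compute(0);                  // consume even buffer
            load(0, k + 2 * K0PerBlock); // prefetch even buffer (guard in real code)
            __syncthreads();
            compute(1);                  // consume odd buffer
            k += 2 * K0PerBlock;
        } while(k < K0 - 2 * K0PerBlock);
        // tail: remaining one or two slices, per HasDoubleTailKBlockLoop
    }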
@@ -780,5 +820,5 @@ namespace ck
                c_grid_buf);
        }
    }
};

} // namespace ck