Commit f88c2f86, authored Dec 13, 2023 by Harisankar Sadasivan
Parent: c2784145

    kernarg load latency optimization for mi300

Showing 5 changed files, with 1085 additions and 976 deletions (+1085 −976).
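Read together, the diffs below suggest (a hedged summary, not the author's wording) that the split-K GEMV/TSMM kernels stop taking a single Argument struct and instead take their pointers and scalars as individual kernel parameters, the examples are compiled with -mllvm --amdgpu-kernarg-preload-count=16, and gfx940/941/942 are added to the supported-device list, so that on MI300 the leading kernel-argument dwords can be preloaded into SGPRs rather than fetched with scalar loads at kernel start.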
example/53_gemv_splitk/CMakeLists.txt                                                +1    −0
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt                                +1    −0
include/ck/host_utility/kernel_launch.hpp                                            +82   −6
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp   +337  −318
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp        +664  −652
example/53_gemv_splitk/CMakeLists.txt

@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
     add_dependencies(example_gemv_splitk example_gemv_splitk_fp16)
+    set_source_files_properties(gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt

@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
     add_dependencies(example_tall_and_skinny_gemm_splitk example_tall_and_skinny_gemm_splitk_fp16)
+    set_source_files_properties(tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
\ No newline at end of file
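Both examples get the same COMPILE_OPTIONS line: -DKERNARG_PRELOAD selects the hipLaunchKernelGGL code path added in kernel_launch.hpp below, and -mllvm;--amdgpu-kernarg-preload-count=16 asks the LLVM AMDGPU backend to preload up to 16 dwords of the kernel-argument segment into SGPRs on hardware that supports it. A minimal sketch of the kind of signature this rewards, using a hypothetical HIP kernel rather than code from this commit:

    #include <hip/hip_runtime.h>

    // Hypothetical kernel with a flat argument list: two pointers (2 dwords
    // each) and two 4-byte scalars occupy the first 6 kernarg dwords, all
    // inside the 16-dword preload window, so no s_load is needed to read them.
    __global__ void axpy(const float* __restrict__ x,
                         float* __restrict__ y,
                         float a,
                         int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            y[i] = a * x[i] + y[i];
    }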
include/ck/host_utility/kernel_launch.hpp

@@ -9,8 +9,9 @@
 #include "ck/stream_config.hpp"
 #include "ck/host_utility/hip_check_error.hpp"

+#ifndef KERNARG_PRELOAD
 template <typename... Args, typename F>
 float launch_and_time_kernel(const StreamConfig& stream_config,
                              F kernel,
                              dim3 grid_dim,
                              dim3 block_dim,
@@ -18,7 +19,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                              Args... args)
 {
 #if CK_TIME_KERNEL
     if(stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
@@ -48,7 +49,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));

         for(int i = 0; i < nrepeat; ++i)
         {
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
             hip_check_error(hipGetLastError());
@@ -78,8 +79,83 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #endif
 }
+#else
+template <typename... Args, typename F>
+float launch_and_time_kernel(const StreamConfig& stream_config,
+                             F kernel,
+                             dim3 grid_dim,
+                             dim3 block_dim,
+                             std::size_t lds_byte,
+                             Args... args)
+{
+    // Args* args1;
+    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
+    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
+#if CK_TIME_KERNEL
+    if(stream_config.time_kernel_)
+    {
+#if DEBUG_LOG
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+        printf("Warm up 1 time\n");
+#endif
+        // warm up
+        const int nrepeat = 1000;
+        for(auto i = 0; i < nrepeat; i++)
+            hipLaunchKernelGGL(
+                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
+        hip_check_error(hipGetLastError());
+#if DEBUG_LOG
+        printf("Start running %d times...\n", nrepeat);
+#endif
+        hipEvent_t start, stop;
+        float total_time = 0;
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+        for(int i = 0; i < nrepeat; ++i)
+            hipLaunchKernelGGL(
+                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
+        // hip_check_error(hipGetLastError());
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+        return total_time / nrepeat;
+    }
+    else
+    {
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        hip_check_error(hipGetLastError());
+        return 0;
+    }
+#else
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+    hip_check_error(hipGetLastError());
+    return 0;
+#endif
+}
+#endif

 template <typename... Args, typename F, typename PreProcessFunc>
 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
                                              F kernel,
                                              dim3 grid_dim,
@@ -88,7 +164,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
     if(stream_config.time_kernel_)
     {
 #if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
@@ -119,7 +195,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
         hip_check_error(hipDeviceSynchronize());
         hip_check_error(hipEventRecord(start, stream_config.stream_id_));

         for(int i = 0; i < nrepeat; ++i)
         {
             preprocess();
             kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
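A hypothetical call site for the KERNARG_PRELOAD variant above, reusing the axpy sketch from earlier; it assumes StreamConfig aggregate-initializes from a stream handle and a timing flag, as it does elsewhere in CK. The point is that the parameter pack stays flat all the way down to hipLaunchKernelGGL:

    // Sketch only: d_x, d_y are assumed device arrays of length n.
    StreamConfig stream_config{nullptr, /*time_kernel_=*/true};
    const int block = 256;
    const int grid  = (n + block - 1) / block;
    float avg_ms = launch_and_time_kernel(stream_config,
                                          axpy,
                                          dim3(grid),
                                          dim3(block),
                                          0, // no dynamic LDS
                                          d_x,
                                          d_y,
                                          2.0f,
                                          n);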
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp

@@ -16,11 +16,14 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

 template <
     typename ADataType,
     typename BDataType,
     typename CDataType,
@@ -58,7 +61,7 @@ template <
         is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
             is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
         bool> = false>
 struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                         BLayout,
                                         CLayout,
                                         ADataType,
@@ -68,7 +71,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                         BElementwiseOperation,
                                         CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -113,11 +116,11 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
         {
             const index_t grid_size =
                 GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
-            // const auto b2c_map = DefaultBlock2CTileMap{};
+            const auto b2c_map = DefaultBlock2CTileMap{};
             const auto K0 = karg.K0;
@@ -127,128 +130,144 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
             float ave_time = 0;

             if(karg.k_batch > 1)
                 hipGetErrorString(
                     hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

             if(has_main_k_block_loop && has_double_tail_k_block_loop)
             {
                 if(karg.k_batch == 1)
                 {
                     const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
                                                             ADataType,
                                                             CDataType,
+                                                            BLayout,
                                                             InMemoryDataOperationEnum::Set,
                                                             true,
                                                             true,
                                                             DefaultBlock2CTileMap>;
-                    ave_time = launch_and_time_kernel(
-                        stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
+                    ave_time = launch_and_time_kernel(stream_config,
+                                                      kernel,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      karg.p_a_grid,
+                                                      karg.p_b_grid,
+                                                      karg.p_c_grid,
+                                                      (karg.M),
+                                                      (karg.N),
+                                                      (karg.K),
+                                                      (karg.K0),
+                                                      (karg.k_batch),
+                                                      karg.MPadded,
+                                                      karg.NPadded,
+                                                      b2c_map);
                 }
                 else
                 {
                     // identical change, with InMemoryDataOperationEnum::AtomicAdd in
                     // place of Set
                 }
             }
             else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
             {
                 // same Set/AtomicAdd pair, instantiated with <true, false>; every
                 // launch now passes the unpacked karg fields plus b2c_map as above
             }
             else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
             {
                 // same pair with <false, true>
             }
             else
             {
                 // same pair with <false, false>
             }

             return ave_time;
         }

         // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {
             return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
@@ -258,12 +277,12 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         return true;
     }

     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" ||
+           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
         {
             return GridwiseTsmm::CheckValidity(arg);
         }
@@ -274,14 +293,14 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     }

     // polymorphic
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
                              CDataType* p_c,
                              index_t M,
                              index_t N,
                              index_t K,
@@ -302,8 +321,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                         StrideA,
                         StrideB,
                         StrideC,
-                        // GridwiseTsmm::CalculateMPadded(M),
-                        // GridwiseTsmm::CalculateNPadded(N),
+                        GridwiseTsmm::CalculateMPadded(M),
+                        GridwiseTsmm::CalculateNPadded(N),
                         // GridwiseTsmm::CalculateKPadded(K, KBatch),
                         GridwiseTsmm::CalculateK0(K, KBatch),
                         KBatch};
@@ -312,9 +331,9 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
                                                       void* p_c,
                                                       index_t M,
                                                       index_t N,
                                                       index_t K,
@@ -327,17 +346,17 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                                                       ck::index_t KBatch = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          // GridwiseTsmm::CalculateMPadded(M),
-                                          // GridwiseTsmm::CalculateNPadded(N),
+                                          GridwiseTsmm::CalculateMPadded(M),
+                                          GridwiseTsmm::CalculateNPadded(N),
                                           // GridwiseTsmm::CalculateKPadded(K, KBatch),
                                           GridwiseTsmm::CalculateK0(K, KBatch),
                                           KBatch);
@@ -370,8 +389,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         return str.str();
     }
 };

 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
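The eight launch sites in Invoker::Run differ only in the template booleans (main K loop, double-tail loop) and in the store policy chosen by k_batch. The split-K half of that choice is worth restating; the sketch below is a paraphrase of the rule the branches implement, not new behavior:

    // k_batch == 1: one workgroup owns each C tile, so the kernel is
    // instantiated with InMemoryDataOperationEnum::Set and writes its
    // result directly.
    //
    // k_batch > 1: K is split across workgroups, so C is first zeroed on
    // the host (the hipMemset above) and every K-batch accumulates its
    // partial product with InMemoryDataOperationEnum::AtomicAdd.
    if(karg.k_batch > 1)
        hipGetErrorString(
            hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));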
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp

@@ -16,31 +16,46 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 namespace ck {

 template <typename GridwiseTsmm,
           typename FloatAB,
           typename FloatC,
+          typename BLayout,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop,
           typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_tsmm_dl_v1r3(typename GridwiseTsmm::Argument karg)
+        kernel_tsmm_dl_v1r3(const FloatAB* p_a_grid,
+                            const FloatAB* p_b_grid,
+                            FloatC* p_c_grid,
+                            index_t M,
+                            index_t N,
+                            index_t K,
+                            index_t K0,
+                            index_t k_batch,
+                            index_t MPadded,
+                            index_t NPadded,
+                            const Block2CTileMap block_2_ctile_map)
 //: in __global__ functions, struct is better for reduced load overhead
 {
-    GridwiseTsmm::template Run<HasMainKBlockLoop,
-                               HasDoubleTailKBlockLoop,
-                               GridwiseTsmm,
-                               CGlobalMemoryDataOperation>(karg);
+    // strides depend on B's layout
+    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+    {
+        GridwiseTsmm::template Run<HasMainKBlockLoop,
+                                   HasDoubleTailKBlockLoop,
+                                   GridwiseTsmm,
+                                   CGlobalMemoryDataOperation>(
+            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch,
+            K, N, N, MPadded, NPadded, block_2_ctile_map);
+    }
+    else
+    {
+        GridwiseTsmm::template Run<HasMainKBlockLoop,
+                                   HasDoubleTailKBlockLoop,
+                                   GridwiseTsmm,
+                                   CGlobalMemoryDataOperation>(
+            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch,
+            K, K, N, MPadded, NPadded, block_2_ctile_map);
+    }
 }

 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
@@ -68,8 +83,8 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
 struct GridwiseTsmmDl_km_kn_mn
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -81,17 +96,17 @@ struct GridwiseTsmmDl_km_kn_mn
     // Argument
     struct Argument : public tensor_operation::device::BaseArgument
     {
         Argument(const FloatAB* p_a_grid_,
                  const FloatAB* p_b_grid_,
                  FloatC* p_c_grid_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
-                 // index_t MPadded_,
-                 // index_t NPadded_,
+                 index_t MPadded_,
+                 index_t NPadded_,
                  // index_t KPadded_,
                  index_t K0_,
                  index_t k_batch_)
@@ -104,8 +119,8 @@ struct GridwiseTsmmDl_km_kn_mn
               StrideA{StrideA_},
               StrideB{StrideB_},
               StrideC{StrideC_},
-              // MPadded(MPadded_),
-              // NPadded(NPadded_),
+              MPadded(MPadded_),
+              NPadded(NPadded_),
               // KPadded(KPadded_),
               K0(K0_),
               k_batch(k_batch_)
@@ -113,15 +128,15 @@ struct GridwiseTsmmDl_km_kn_mn
         }

         // private:
         const FloatAB* p_a_grid;
         const FloatAB* p_b_grid;
         FloatC* p_c_grid;
         index_t M, N, K;
         index_t StrideA, StrideB, StrideC;
-        // index_t MPadded;
-        // index_t NPadded;
+        index_t MPadded;
+        index_t NPadded;
         // index_t KPadded;
         index_t K0;
         index_t k_batch;
@@ -199,18 +214,19 @@ struct GridwiseTsmmDl_km_kn_mn
         index_t M, index_t MPad, index_t K, index_t StrideA, index_t KBatch, index_t K0)
     {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
             }
         }();

         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
@@ -239,18 +255,19 @@ struct GridwiseTsmmDl_km_kn_mn
         index_t K, index_t NPad, index_t N, index_t StrideB, index_t KBatch, index_t K0)
     {
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
             }
         }();

         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             return transform_tensor_descriptor(
@@ -273,18 +290,19 @@ struct GridwiseTsmmDl_km_kn_mn
     __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
         const auto c_grid_desc_m_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
             }
         }();

         if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -317,15 +335,15 @@ struct GridwiseTsmmDl_km_kn_mn
     using BGridDesc_Kbatch_K0_N_K1 =
         decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
     {
-        const auto MPadded = CalculateMPadded(karg.M);
-        const auto NPadded = CalculateNPadded(karg.N);
+        // const auto MPadded = CalculateMPadded(karg.M);
+        // const auto NPadded = CalculateNPadded(karg.N);
         const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
         const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
+            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);

         const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
@@ -343,7 +361,7 @@ struct GridwiseTsmmDl_km_kn_mn
     // KBatch, K0, M, K1 -> KBatch, K0, M0, M1 (MPerBlock), K1
     __host__ __device__ static constexpr auto MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(
         const AGridDesc_Kbatch_K0_M_K1& a_grid_desc_kbatch_k0_m_k1)
     {
         const auto KBatch = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
         const auto K0     = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
@@ -365,7 +383,7 @@ struct GridwiseTsmmDl_km_kn_mn
     }

     __host__ __device__ static constexpr auto MakeBGridDescriptor_Kbatch_K0_N0_N1_K1(
         const BGridDesc_Kbatch_K0_N_K1& b_grid_desc_kbatch_k0_n_k1)
     {
         const auto KBatch = b_grid_desc_kbatch_k0_n_k1.GetLength(I0);
         const auto K0     = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
@@ -387,7 +405,7 @@ struct GridwiseTsmmDl_km_kn_mn
     }

     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -433,27 +451,21 @@ struct GridwiseTsmmDl_km_kn_mn
               bool HasDoubleTailKBlockLoop,
               typename GridwiseTsmm,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg)
+    __device__ static void Run(const FloatAB* p_a_grid,
+                               const FloatAB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t K0,
+                               index_t k_batch,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               index_t MPadded,
+                               index_t NPadded,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         constexpr index_t shared_block_size =
             GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
         __shared__ FloatAB p_shared_block[shared_block_size];

-        const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
-        const auto MPadded = CalculateMPadded(karg.M);
-        const auto NPadded = CalculateNPadded(karg.N);
-        const FloatAB* p_a_grid = karg.p_a_grid;
-        const FloatAB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid = karg.p_c_grid;
         const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
+            M, MPadded, K, StrideA, k_batch, K0); //
         const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
+            K, NPadded, N, StrideB, k_batch, K0); //
         const auto c_grid_desc_m_n =
-            GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+            GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);

         const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
             GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1);
@@ -471,14 +483,14 @@ struct GridwiseTsmmDl_km_kn_mn
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

         const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
-            get_block_1d_id(), karg.N, karg.k_batch);
+            get_block_1d_id(), N, k_batch);

         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
         const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
         const index_t kbatch_id = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I2]);

         if(!block_2_ctile_map.ValidCTileIndex(
                make_tuple(im0, in0),
                make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                           c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
@@ -581,7 +593,7 @@ struct GridwiseTsmmDl_km_kn_mn
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;

         auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_k0_n_k1_thread_desc.GetElementSpaceSize());
@@ -620,9 +632,9 @@ struct GridwiseTsmmDl_km_kn_mn
                 b_thread_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
         {
-            const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
+            // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
             index_t k_block_data_begin = 0;
@@ -679,11 +691,11 @@ struct GridwiseTsmmDl_km_kn_mn
                 a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1, a_block_even_buf);

                 k_block_data_begin += 2 * K0PerBlock;
             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
         }

         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_kbatch_k0_m0_m1_k1,
                                                 a_block_slice_copy_step);
@@ -768,5 +780,5 @@ struct GridwiseTsmmDl_km_kn_mn
             c_grid_buf);
     }
 };

 } // namespace ck
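The signature change in this file is the heart of the commit: the old kernel took one by-value GridwiseTsmm::Argument, while the new one takes three pointers and seven index_t scalars directly. A reduced before/after sketch follows, with hypothetical concrete types standing in for the template parameters; the dword numbering assumes 64-bit pointers and 32-bit index_t, and the premise (hedged, consistent with the commit but not stated in it) is that the kernarg-preload pass maps leading individual parameters onto SGPRs more readily than one aggregate:

    #include <hip/hip_runtime.h>
    using index_t = int32_t;

    struct Args // before: one aggregate in the kernarg segment
    {
        const float* p_a_grid;
        const float* p_b_grid;
        float* p_c_grid;
        index_t M, N, K, K0, k_batch;
    };

    __global__ void kernel_struct(Args karg)
    {
        // every field access is a load relative to the aggregate's kernarg offset
    }

    __global__ void kernel_flat(const float* p_a_grid, // kernarg dwords 0-1
                                const float* p_b_grid, // dwords 2-3
                                float* p_c_grid,       // dwords 4-5
                                index_t M,             // dword 6
                                index_t N,             // dword 7
                                index_t K,             // dword 8
                                index_t K0,            // dword 9
                                index_t k_batch)       // dword 10
    {
        // all eleven dwords sit inside the 16-dword preload window
    }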