Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8861bd66
"test/vscode:/vscode.git/clone" did not exist on "f22f38f7f260aad80b66f5e08fa3aaab2b0f7b49"
Commit
8861bd66
authored
Dec 13, 2023
by
Harisankar Sadasivan
Browse files
kernarg load latency optimization for mi300
parent
e42b36ee
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
325 additions
and
74 deletions
+325
-74
example/53_gemv_splitk/CMakeLists.txt
example/53_gemv_splitk/CMakeLists.txt
+1
-0
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt
+1
-0
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc
...y_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc
+15
-13
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+75
-0
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp
...on/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp
+143
-23
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp
...eration/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp
+90
-38
No files found.
example/53_gemv_splitk/CMakeLists.txt
View file @
8861bd66
...
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable
(
example_gemv_splitk_fp16 gemv_splitk_fp16.cpp
)
add_example_executable
(
example_gemv_splitk_fp16 gemv_splitk_fp16.cpp
)
add_dependencies
(
example_gemv_splitk
add_dependencies
(
example_gemv_splitk
example_gemv_splitk_fp16
)
example_gemv_splitk_fp16
)
set_source_files_properties
(
gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS
"-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16"
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt
View file @
8861bd66
...
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable
(
example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp
)
add_example_executable
(
example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp
)
add_dependencies
(
example_tall_and_skinny_gemm_splitk
add_dependencies
(
example_tall_and_skinny_gemm_splitk
example_tall_and_skinny_gemm_splitk_fp16
)
example_tall_and_skinny_gemm_splitk_fp16
)
set_source_files_properties
(
tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS
"-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16"
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc
100644 → 100755
View file @
8861bd66
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#pragma once
#pragma once
bool
run_
tall_and_skinny_
gem
m
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
bool
run_gem
v
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
...
@@ -72,9 +72,9 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
...
@@ -72,9 +72,9 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
// do GEMM
auto
tsmm
=
Device
TSMM
Instance
{};
auto
gemv
=
Device
Gemv
Instance
{};
auto
invoker
=
tsmm
.
MakeInvoker
();
auto
invoker
=
gemv
.
MakeInvoker
();
auto
argument
=
tsmm
.
MakeArgument
(
auto
argument
=
gemv
.
MakeArgument
(
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
...
@@ -96,22 +96,24 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
...
@@ -96,22 +96,24 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
k_batch
);
// //
k_batch
);
// //
// //
// //
if
(
!
tsmm
.
IsSupportedArgument
(
argument
))
if
(
!
gemv
.
IsSupportedArgument
(
argument
))
{
{
std
::
cerr
<<
tsmm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cerr
<<
gemv
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
return
true
;
}
}
c_m_n_device_buf
.
SetZero
();
c_m_n_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
// Run prior to verification
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
// Run prior to verification
auto
ref_tsmm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_tsmm
.
MakeInvoker
();
auto
ref_argument
=
ref_tsmm
.
MakeArgument
(
auto
ref_gemv
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemv
.
MakeInvoker
();
auto
ref_argument
=
ref_gemv
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
@@ -141,7 +143,7 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
...
@@ -141,7 +143,7 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
tsmm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemv
.
GetTypeString
()
<<
std
::
endl
;
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
...
@@ -150,7 +152,7 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
...
@@ -150,7 +152,7 @@ bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionCo
#endif
#endif
}
}
bool
run_
tall_and_skinny_
gem
m
_example
(
int
argc
,
char
*
argv
[])
bool
run_gem
v
_example
(
int
argc
,
char
*
argv
[])
{
{
ProblemSize
problem_size
;
ProblemSize
problem_size
;
ExecutionConfig
config
;
ExecutionConfig
config
;
...
@@ -190,5 +192,5 @@ bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
...
@@ -190,5 +192,5 @@ bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
exit
(
0
);
exit
(
0
);
}
}
return
run_
tall_and_skinny_
gem
m
(
problem_size
,
config
);
return
run_gem
v
(
problem_size
,
config
);
}
}
include/ck/host_utility/kernel_launch.hpp
View file @
8861bd66
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "ck/stream_config.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#ifndef KERNARG_PRELOAD
template
<
typename
...
Args
,
typename
F
>
template
<
typename
...
Args
,
typename
F
>
float
launch_and_time_kernel
(
const
StreamConfig
&
stream_config
,
float
launch_and_time_kernel
(
const
StreamConfig
&
stream_config
,
F
kernel
,
F
kernel
,
...
@@ -78,6 +79,80 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -78,6 +79,80 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif
#endif
}
}
#else
/// Launch `kernel` and, when timing is requested, report its average run time.
///
/// This is the variant selected when KERNARG_PRELOAD is defined (the other
/// branch of the surrounding `#ifndef KERNARG_PRELOAD`).
///
/// @param stream_config  Supplies the stream (`stream_id_`) to launch on and the
///                       `time_kernel_` flag that enables timing.
/// @param kernel         Kernel entry point to launch.
/// @param grid_dim       Grid dimensions for the launch.
/// @param block_dim      Block dimensions for the launch.
/// @param lds_byte       Dynamic shared-memory size passed to the launch.
/// @param args           Kernel arguments, forwarded by value to every launch.
/// @return When CK_TIME_KERNEL is enabled and `stream_config.time_kernel_` is
///         set: the elapsed time between the start/stop events divided by the
///         number of timed launches (hipEventElapsedTime reports milliseconds).
///         Otherwise the kernel is launched once and 0 is returned.
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
    // Args* args1;
    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
        // The same count is used for the warm-up launches and the timed launches.
        const int nrepeat = 1000;

#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        // Report the real number of warm-up launches (previously hard-coded as "1").
        printf("Warm up %d times\n", nrepeat);
#endif
        // warm up
        for(int i = 0; i < nrepeat; ++i)
        {
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        }
        hip_check_error(hipGetLastError());

#if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif
        hipEvent_t start, stop;
        float total_time = 0;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
        }
        // Checked once after the loop (not per launch) so a failed launch is
        // reported instead of silently yielding a meaningless timing.
        hip_check_error(hipGetLastError());

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));

        // Release the timing events; they were previously leaked on every call.
        hip_check_error(hipEventDestroy(start));
        hip_check_error(hipEventDestroy(stop));

        return total_time / nrepeat;
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
    hip_check_error(hipGetLastError());

    return 0;
#endif
}
#endif
template
<
typename
...
Args
,
typename
F
,
typename
PreProcessFunc
>
template
<
typename
...
Args
,
typename
F
,
typename
PreProcessFunc
>
float
launch_and_time_kernel_with_preprocess
(
const
StreamConfig
&
stream_config
,
float
launch_and_time_kernel_with_preprocess
(
const
StreamConfig
&
stream_config
,
PreProcessFunc
preprocess
,
PreProcessFunc
preprocess
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp
View file @
8861bd66
...
@@ -117,7 +117,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -117,7 +117,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
{
{
const
index_t
grid_size
=
GridwiseTsmm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
index_t
grid_size
=
GridwiseTsmm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
//
const auto b2c_map = DefaultBlock2CTileMap{};
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
const
auto
K0
=
karg
.
K0
;
const
auto
K0
=
karg
.
K0
;
...
@@ -138,24 +138,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -138,24 +138,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
true
,
true
,
true
,
true
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
true
,
true
,
true
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
...
@@ -166,24 +196,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -166,24 +196,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
true
,
true
,
false
,
false
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
true
,
false
,
false
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
...
@@ -193,24 +253,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -193,24 +253,54 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
false
,
false
,
true
,
true
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
false
,
false
,
true
,
true
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
}
}
else
else
...
@@ -220,30 +310,59 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -220,30 +310,59 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
false
,
false
,
false
,
false
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
const
auto
kernel
=
kernel_tsmm_dl_v1r3
<
GridwiseTsmm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
BLayout
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
false
,
false
,
false
,
false
,
DefaultBlock2CTileMap
>
;
// //
DefaultBlock2CTileMap
>
;
// //
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
(
karg
.
M
),
(
karg
.
N
),
(
karg
.
K
),
(
karg
.
K0
),
(
karg
.
k_batch
),
karg
.
MPadded
,
karg
.
NPadded
,
b2c_map
);
}
}
}
}
return
ave_time
;
return
ave_time
;
}
}
// polymorphic
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
...
@@ -263,7 +382,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -263,7 +382,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
||
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
||
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
)
{
{
return
GridwiseTsmm
::
CheckValidity
(
arg
);
return
GridwiseTsmm
::
CheckValidity
(
arg
);
}
}
...
@@ -302,8 +422,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -302,8 +422,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
//
GridwiseTsmm::CalculateMPadded(M),
GridwiseTsmm
::
CalculateMPadded
(
M
),
//
GridwiseTsmm::CalculateNPadded(N),
GridwiseTsmm
::
CalculateNPadded
(
N
),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm
::
CalculateK0
(
K
,
KBatch
),
GridwiseTsmm
::
CalculateK0
(
K
,
KBatch
),
KBatch
};
// //
KBatch
};
// //
...
@@ -336,8 +456,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
...
@@ -336,8 +456,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
//
GridwiseTsmm::CalculateMPadded(M),
GridwiseTsmm
::
CalculateMPadded
(
M
),
//
GridwiseTsmm::CalculateNPadded(N),
GridwiseTsmm
::
CalculateNPadded
(
N
),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm
::
CalculateK0
(
K
,
KBatch
),
GridwiseTsmm
::
CalculateK0
(
K
,
KBatch
),
KBatch
);
// //
KBatch
);
// //
...
...
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp
View file @
8861bd66
...
@@ -21,23 +21,70 @@ namespace ck {
...
@@ -21,23 +21,70 @@ namespace ck {
template
<
typename
GridwiseTsmm
,
template
<
typename
GridwiseTsmm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
BLayout
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
,
bool
HasDoubleTailKBlockLoop
,
typename
Block2CTileMap
>
typename
Block2CTileMap
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_tsmm_dl_v1r3
(
kernel_tsmm_dl_v1r3
(
typename
GridwiseTsmm
::
Argument
karg
)
//: in __global__ functions, struct is
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
K0
,
index_t
k_batch
,
index_t
MPadded
,
index_t
NPadded
,
const
Block2CTileMap
block_2_ctile_map
)
//: in __global__ functions, struct is
// better for reduced load overhead
// better for reduced load overhead
{
{
// strides depend on B's layout
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
GridwiseTsmm
::
template
Run
<
HasMainKBlockLoop
,
HasDoubleTailKBlockLoop
,
GridwiseTsmm
,
CGlobalMemoryDataOperation
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
M
,
N
,
K
,
K0
,
k_batch
,
K
,
N
,
N
,
MPadded
,
NPadded
,
block_2_ctile_map
);
}
else
{
GridwiseTsmm
::
template
Run
<
HasMainKBlockLoop
,
GridwiseTsmm
::
template
Run
<
HasMainKBlockLoop
,
HasDoubleTailKBlockLoop
,
HasDoubleTailKBlockLoop
,
GridwiseTsmm
,
GridwiseTsmm
,
CGlobalMemoryDataOperation
>(
karg
);
CGlobalMemoryDataOperation
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
M
,
N
,
K
,
K0
,
k_batch
,
K
,
K
,
N
,
MPadded
,
NPadded
,
block_2_ctile_map
);
}
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
...
@@ -90,8 +137,8 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -90,8 +137,8 @@ struct GridwiseTsmmDl_km_kn_mn
index_t
StrideA_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
StrideC_
,
//
index_t MPadded_,
index_t
MPadded_
,
//
index_t NPadded_,
index_t
NPadded_
,
// index_t KPadded_,
// index_t KPadded_,
index_t
K0_
,
index_t
K0_
,
index_t
k_batch_
)
index_t
k_batch_
)
...
@@ -104,8 +151,8 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -104,8 +151,8 @@ struct GridwiseTsmmDl_km_kn_mn
StrideA
{
StrideA_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
StrideC
{
StrideC_
},
//
MPadded(MPadded_),
MPadded
(
MPadded_
),
//
NPadded(NPadded_),
NPadded
(
NPadded_
),
// KPadded(KPadded_),
// KPadded(KPadded_),
K0
(
K0_
),
K0
(
K0_
),
k_batch
(
k_batch_
)
k_batch
(
k_batch_
)
...
@@ -120,8 +167,8 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -120,8 +167,8 @@ struct GridwiseTsmmDl_km_kn_mn
index_t
M
,
N
,
K
;
index_t
M
,
N
,
K
;
index_t
StrideA
,
StrideB
,
StrideC
;
index_t
StrideA
,
StrideB
,
StrideC
;
//:
//:
//
index_t MPadded;
index_t
MPadded
;
//
index_t NPadded;
index_t
NPadded
;
// index_t KPadded;
// index_t KPadded;
index_t
K0
;
index_t
K0
;
index_t
k_batch
;
index_t
k_batch
;
...
@@ -320,12 +367,12 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -320,12 +367,12 @@ struct GridwiseTsmmDl_km_kn_mn
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
{
const
auto
MPadded
=
CalculateMPadded
(
karg
.
M
);
//
const auto MPadded = CalculateMPadded(karg.M);
const
auto
NPadded
=
CalculateNPadded
(
karg
.
N
);
//
const auto NPadded = CalculateNPadded(karg.N);
const
auto
a_grid_desc_kbatch_k0_m_k1
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
const
auto
a_grid_desc_kbatch_k0_m_k1
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
);
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
);
const
auto
b_grid_desc_kbatch_k0_n_k1
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
const
auto
b_grid_desc_kbatch_k0_n_k1
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
);
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
KBatch_a
=
a_grid_desc_kbatch_k0_m_k1
.
GetLength
(
I0
);
const
auto
KBatch_a
=
a_grid_desc_kbatch_k0_m_k1
.
GetLength
(
I0
);
...
@@ -433,27 +480,32 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -433,27 +480,32 @@ struct GridwiseTsmmDl_km_kn_mn
bool
HasDoubleTailKBlockLoop
,
bool
HasDoubleTailKBlockLoop
,
typename
GridwiseTsmm
,
typename
GridwiseTsmm
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
__device__
static
void
Run
(
const
Argument
&
karg
)
__device__
static
void
Run
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
K0
,
index_t
k_batch
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
MPadded
,
index_t
NPadded
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
GridwiseTsmm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseTsmm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
const
Block2CTileMap
&
block_2_ctile_map
=
Block2CTileMap
{};
const
auto
MPadded
=
CalculateMPadded
(
karg
.
M
);
const
auto
NPadded
=
CalculateNPadded
(
karg
.
N
);
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_grid_desc_kbatch_k0_m_k1
=
GridwiseTsmm
::
MakeAGridDescriptor_KBatch_K0_M_K1
(
const
auto
a_grid_desc_kbatch_k0_m_k1
=
GridwiseTsmm
::
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
);
//
M
,
MPadded
,
K
,
StrideA
,
k_batch
,
K0
);
//
const
auto
b_grid_desc_kbatch_k0_n_k1
=
GridwiseTsmm
::
MakeBGridDescriptor_KBatch_K0_N_K1
(
const
auto
b_grid_desc_kbatch_k0_n_k1
=
GridwiseTsmm
::
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
);
//
K
,
NPadded
,
N
,
StrideB
,
k_batch
,
K0
);
//
const
auto
c_grid_desc_m_n
=
const
auto
c_grid_desc_m_n
=
GridwiseTsmm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
GridwiseTsmm
::
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
a_grid_desc_kbatch_k0_m0_m1_k1
=
const
auto
a_grid_desc_kbatch_k0_m0_m1_k1
=
GridwiseTsmm
::
MakeAGridDescriptor_Kbatch_K0_M0_M1_K1
(
a_grid_desc_kbatch_k0_m_k1
);
//
GridwiseTsmm
::
MakeAGridDescriptor_Kbatch_K0_M0_M1_K1
(
a_grid_desc_kbatch_k0_m_k1
);
//
...
@@ -470,8 +522,8 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -470,8 +522,8 @@ struct GridwiseTsmmDl_km_kn_mn
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetElementSpaceSize
());
p_c_grid
,
c_grid_desc_m0_m10_m11_n0_n10_n11
.
GetElementSpaceSize
());
const
auto
c_m0_n0_block_cluster_idx
=
block_2_ctile_map
.
convert_1D_block_idx_to_3D_tuple
(
const
auto
c_m0_n0_block_cluster_idx
=
get_block_1d_id
(),
karg
.
N
,
karg
.
k_batch
);
block_2_ctile_map
.
convert_1D_block_idx_to_3D_tuple
(
get_block_1d_id
(),
N
,
k_batch
);
// HACK: this force index data into SGPR
// HACK: this force index data into SGPR
const
index_t
im0
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I0
]);
const
index_t
im0
=
__builtin_amdgcn_readfirstlane
(
c_m0_n0_block_cluster_idx
[
I0
]);
...
@@ -622,7 +674,7 @@ struct GridwiseTsmmDl_km_kn_mn
...
@@ -622,7 +674,7 @@ struct GridwiseTsmmDl_km_kn_mn
if
constexpr
(
HasMainKBlockLoop
)
if
constexpr
(
HasMainKBlockLoop
)
{
{
const
auto
K0
=
a_grid_desc_kbatch_k0_m0_m1_k1
.
GetLength
(
I1
);
//
const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
index_t
k_block_data_begin
=
0
;
index_t
k_block_data_begin
=
0
;
...
...
gaoqiong
@gaoqiong
mentioned in commit
c2784145
·
Feb 18, 2025
mentioned in commit
c2784145
mentioned in commit c2784145d6d55c4accfad7760196b1eea80a4b36
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment