gaoqiong / composable_kernel_ROCM · Commits

Commit c2784145, authored Dec 13, 2023 by Harisankar Sadasivan

Revert "kernarg load latency optimization for mi300"

This reverts commit 8861bd66.

parent 8861bd66

Changes: 6 changed files with 74 additions and 325 deletions (+74 -325)
Files changed:
  example/53_gemv_splitk/CMakeLists.txt                                               +0   -1
  example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt                               +0   -1
  example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc  +13  -15
  include/ck/host_utility/kernel_launch.hpp                                           +0   -75
  include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp  +23  -143
  include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp       +38  -90
example/53_gemv_splitk/CMakeLists.txt

@@ -6,7 +6,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
     add_dependencies(example_gemv_splitk example_gemv_splitk_fp16)
-    set_source_files_properties(gemv_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
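Note on the removed line (here and in the next file): `-DKERNARG_PRELOAD` is an ordinary preprocessor define that selects the alternate launcher deleted from include/ck/host_utility/kernel_launch.hpp further down, while `-mllvm;--amdgpu-kernarg-preload-count=16` asks the LLVM AMDGPU backend to preload the leading kernel-argument dwords into SGPRs. A minimal standalone sketch of the define-driven gate (toy program, not CK code; build with and without -DKERNARG_PRELOAD to flip the branch):

    #include <cstdio>

    #ifndef KERNARG_PRELOAD
    static const char* launcher_variant() { return "default launcher"; }
    #else
    static const char* launcher_variant() { return "kernarg-preload launcher"; }
    #endif

    int main()
    {
        // e.g. `hipcc -DKERNARG_PRELOAD gate.cpp` vs plain `hipcc gate.cpp`
        std::printf("compiled-in variant: %s\n", launcher_variant());
        return 0;
    }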
example/54_tall_and_skinny_gemm_splitk/CMakeLists.txt

@@ -6,7 +6,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
     add_dependencies(example_tall_and_skinny_gemm_splitk example_tall_and_skinny_gemm_splitk_fp16)
-    set_source_files_properties(tall_and_skinny_gemm_splitk_fp16.cpp PROPERTIES COMPILE_OPTIONS "-DKERNARG_PRELOAD;-Wno-gnu-line-marker;-gline-tables-only;-mllvm;--amdgpu-kernarg-preload-count=16")
     set(target 1)
   endif()
 endforeach()
\ No newline at end of file
example/54_tall_and_skinny_gemm_splitk/run_tall_and_skinny_gemm_splitk_example.inc
(file mode changed 100755 → 100644)

@@ -3,7 +3,7 @@
 #pragma once
 
-bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
+bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 {
 #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
     static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
@@ -72,9 +72,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
     auto c_element_op = CElementOp{};
 
     // do GEMM
-    auto gemv     = DeviceGemvInstance{};
-    auto invoker  = gemv.MakeInvoker();
-    auto argument = gemv.MakeArgument(
+    auto tsmm     = DeviceTSMMInstance{};
+    auto invoker  = tsmm.MakeInvoker();
+    auto argument = tsmm.MakeArgument(
 #ifdef BUILD_INT4_EXAMPLE
         static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
@@ -96,24 +96,22 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
         k_batch);
     // //
 
-    if(!gemv.IsSupportedArgument(argument))
+    if(!tsmm.IsSupportedArgument(argument))
     {
-        std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
+        std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
 
         return true;
     }
 
     c_m_n_device_buf.SetZero();
 
-    invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
-
     if(config.do_verification)
     {
-        auto ref_gemv     = ReferenceGemmInstance{};
-        auto ref_invoker  = ref_gemv.MakeInvoker();
-        auto ref_argument = ref_gemv.MakeArgument(
+        invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
+
+        auto ref_tsmm     = ReferenceGemmInstance{};
+        auto ref_invoker  = ref_tsmm.MakeInvoker();
+        auto ref_argument = ref_tsmm.MakeArgument(
             a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
 
         ref_invoker.Run(ref_argument);
@@ -143,7 +141,7 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
     float gb_per_sec = num_btype / 1.E6 / ave_time;
 
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-              << " GB/s, " << gemv.GetTypeString() << std::endl;
+              << " GB/s, " << tsmm.GetTypeString() << std::endl;
 
 #ifdef BUILD_INT4_EXAMPLE
     return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
@@ -152,7 +150,7 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
 #endif
 }
 
-bool run_gemv_example(int argc, char* argv[])
+bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
 {
     ProblemSize problem_size;
     ExecutionConfig config;
@@ -192,5 +190,5 @@ bool run_gemv_example(int argc, char* argv[])
         exit(0);
    }
 
-    return run_gemv(problem_size, config);
+    return run_tall_and_skinny_gemm(problem_size, config);
 }
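The example's control flow above is: build the argument, ask the device op whether it supports the problem, run once, then compare against a host reference with a tolerance check (ck::utils::check_err plays that role). A self-contained sketch of the same run/verify flow under stated assumptions: the naive loop stands in for both the CK device op and ReferenceGemmInstance, and check_err here is a hand-rolled stand-in, not CK's utility.

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // naive row-major GEMM used for both paths in this sketch
    static std::vector<float> naive_gemm(const std::vector<float>& a,
                                         const std::vector<float>& b,
                                         int M, int N, int K)
    {
        std::vector<float> c(static_cast<std::size_t>(M) * N, 0.0f);
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = 0; k < K; ++k)
                    c[m * N + n] += a[m * K + k] * b[k * N + n];
        return c;
    }

    // stand-in for ck::utils::check_err: elementwise absolute-tolerance compare
    static bool check_err(const std::vector<float>& out,
                          const std::vector<float>& ref,
                          float atol = 1e-5f)
    {
        for(std::size_t i = 0; i < out.size(); ++i)
            if(std::fabs(out[i] - ref[i]) > atol)
                return false;
        return true;
    }

    int main()
    {
        const int M = 1024, N = 4, K = 64; // "tall and skinny": M >> N
        std::vector<float> a(static_cast<std::size_t>(M) * K, 1.0f);
        std::vector<float> b(static_cast<std::size_t>(K) * N, 0.5f);

        auto device_result = naive_gemm(a, b, M, N, K); // stands in for the tsmm device op
        auto host_result   = naive_gemm(a, b, M, N, K); // stands in for ReferenceGemmInstance
        std::printf("verification %s\n",
                    check_err(device_result, host_result) ? "passed" : "FAILED");
        return 0;
    }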
include/ck/host_utility/kernel_launch.hpp

@@ -9,7 +9,6 @@
 #include "ck/stream_config.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
 
-#ifndef KERNARG_PRELOAD
 template <typename... Args, typename F>
 float launch_and_time_kernel(const StreamConfig& stream_config,
                              F kernel,
@@ -79,80 +78,6 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #endif
 }
 
-#else
-template <typename... Args, typename F>
-float launch_and_time_kernel(const StreamConfig& stream_config,
-                             F kernel,
-                             dim3 grid_dim,
-                             dim3 block_dim,
-                             std::size_t lds_byte,
-                             Args... args)
-{
-    // Args* args1;
-    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
-    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
-#if CK_TIME_KERNEL
-    if(stream_config.time_kernel_)
-    {
-#if DEBUG_LOG
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-               __func__,
-               grid_dim.x,
-               grid_dim.y,
-               grid_dim.z,
-               block_dim.x,
-               block_dim.y,
-               block_dim.z);
-
-        printf("Warm up 1 time\n");
-#endif
-        // warm up
-        const int nrepeat = 1000;
-        for(auto i = 0; i < nrepeat; i++)
-            hipLaunchKernelGGL(
-                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
-        hip_check_error(hipGetLastError());
-#if DEBUG_LOG
-        printf("Start running %d times...\n", nrepeat);
-#endif
-        hipEvent_t start, stop;
-        float total_time = 0;
-        hip_check_error(hipEventCreate(&start));
-        hip_check_error(hipEventCreate(&stop));
-        hip_check_error(hipDeviceSynchronize());
-        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
-        for(int i = 0; i < nrepeat; ++i)
-            hipLaunchKernelGGL(
-                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
-        // hip_check_error(hipGetLastError());
-        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
-        hip_check_error(hipEventSynchronize(stop));
-        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
-        return total_time / nrepeat;
-    }
-    else
-    {
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
-#endif
-
 template <typename... Args, typename F, typename PreProcessFunc>
 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
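The deleted launcher warms up, then averages 1000 back-to-back launches between two hipEvents. A minimal standalone sketch of that event-based timing loop (the axpy kernel, grid shape, and error helper are placeholders; the HIP calls themselves are the real API):

    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    static void check(hipError_t err)
    {
        if(err != hipSuccess)
        {
            std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err));
            std::exit(1);
        }
    }

    __global__ void axpy(float a, const float* x, float* y, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            y[i] += a * x[i];
    }

    int main()
    {
        const int n = 1 << 20;
        float *x = nullptr, *y = nullptr;
        check(hipMalloc(reinterpret_cast<void**>(&x), n * sizeof(float)));
        check(hipMalloc(reinterpret_cast<void**>(&y), n * sizeof(float)));
        check(hipMemset(x, 0, n * sizeof(float)));
        check(hipMemset(y, 0, n * sizeof(float)));

        const dim3 grid((n + 255) / 256), block(256);
        const int nrepeat = 1000; // same repeat count as the deleted code

        // warm up before timing, as the deleted launcher did
        hipLaunchKernelGGL(axpy, grid, block, 0, 0, 2.0f, x, y, n);
        check(hipGetLastError());

        hipEvent_t start, stop;
        check(hipEventCreate(&start));
        check(hipEventCreate(&stop));
        check(hipDeviceSynchronize());
        check(hipEventRecord(start, 0));
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(axpy, grid, block, 0, 0, 2.0f, x, y, n);
        check(hipEventRecord(stop, 0));
        check(hipEventSynchronize(stop));

        float total_ms = 0.0f;
        check(hipEventElapsedTime(&total_ms, start, stop));
        std::printf("average kernel time: %f ms\n", total_ms / nrepeat);

        check(hipFree(x));
        check(hipFree(y));
        return 0;
    }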
include/ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp

@@ -117,7 +117,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         {
             const index_t grid_size =
                 GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
-            const auto b2c_map = DefaultBlock2CTileMap{};
+            // const auto b2c_map = DefaultBlock2CTileMap{};
 
             const auto K0 = karg.K0;
@@ -138,54 +138,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::Set,
-                                                        true, true, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        true, true, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
             else
             {
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::AtomicAdd,
-                                                        true, true, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                        true, true, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
         }
         else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
@@ -196,54 +166,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::Set,
-                                                        true, false, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        true, false, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
             else
             {
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::AtomicAdd,
-                                                        true, false, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                        true, false, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
         }
         else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -253,54 +193,24 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::Set,
-                                                        false, true, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        false, true, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
             else
             {
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::AtomicAdd,
-                                                        false, true, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                        false, true, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
         }
         else
@@ -310,59 +220,30 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::Set,
-                                                        false, false, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        false, false, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
             else
            {
-                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
-                                                        BLayout, InMemoryDataOperationEnum::AtomicAdd,
-                                                        false, false, DefaultBlock2CTileMap>;
-                // //
-                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size),
-                                                  dim3(BlockSize), 0, karg.p_a_grid,
-                                                  karg.p_b_grid, karg.p_c_grid, (karg.M),
-                                                  (karg.N), (karg.K), (karg.K0), (karg.k_batch),
-                                                  karg.MPadded, karg.NPadded, b2c_map);
+                const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm, ADataType, CDataType,
+                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                        false, false, DefaultBlock2CTileMap>;
+                // //
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
             }
         }
 
         return ave_time;
     }
 
     // polymorphic
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
@@ -382,8 +263,7 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx940" ||
-           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
+           ck::get_device_name() == "gfx1102")
         {
             return GridwiseTsmm::CheckValidity(arg);
         }
@@ -422,8 +302,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                         StrideA,
                         StrideB,
                         StrideC,
-                        GridwiseTsmm::CalculateMPadded(M),
-                        GridwiseTsmm::CalculateNPadded(N),
+                        // GridwiseTsmm::CalculateMPadded(M),
+                        // GridwiseTsmm::CalculateNPadded(N),
                         // GridwiseTsmm::CalculateKPadded(K, KBatch),
                         GridwiseTsmm::CalculateK0(K, KBatch),
                         KBatch}; // //
@@ -456,8 +336,8 @@ struct deviceTsmmDl : public DeviceTsmm<ALayout,
                         StrideA,
                         StrideB,
                         StrideC,
-                        GridwiseTsmm::CalculateMPadded(M),
-                        GridwiseTsmm::CalculateNPadded(N),
+                        // GridwiseTsmm::CalculateMPadded(M),
+                        // GridwiseTsmm::CalculateNPadded(N),
                         // GridwiseTsmm::CalculateKPadded(K, KBatch),
                         GridwiseTsmm::CalculateK0(K, KBatch),
                         KBatch); // //
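The recurring change above is the kernel-argument calling convention: the reverted optimization flattened GridwiseTsmm::Argument into individual kernel arguments (pointers first, then scalars) so the leading kernarg dwords line up with the SGPRs that --amdgpu-kernarg-preload-count can fill at dispatch; the revert goes back to passing one struct. A hedged sketch contrasting the two forms with toy types (not CK's Argument, and whether flattening actually helps latency is exactly what the reverted experiment was probing):

    #include <hip/hip_runtime.h>

    // toy stand-in for GridwiseTsmm::Argument; the real struct also carries
    // strides, padded sizes, and split-K bookkeeping
    struct Argument
    {
        const float* p_a;
        const float* p_b;
        float* p_c;
        int M, N, K;
    };

    // struct form (what this revert restores): one aggregate kernarg, fields
    // loaded from the kernarg segment as the kernel needs them
    __global__ void kernel_struct(Argument karg)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < karg.M * karg.N)
            karg.p_c[i] = karg.p_a[i] + karg.p_b[i];
    }

    // flattened form (what the reverted optimization used): each field is its
    // own kernel argument
    __global__ void kernel_flat(const float* p_a, const float* p_b, float* p_c,
                                int M, int N, int K)
    {
        (void)K; // unused in this toy
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < M * N)
            p_c[i] = p_a[i] + p_b[i];
    }

    int main()
    {
        const int M = 64, N = 4, K = 1;
        float *a = nullptr, *b = nullptr, *c = nullptr;
        if(hipMalloc(reinterpret_cast<void**>(&a), M * N * sizeof(float)) != hipSuccess ||
           hipMalloc(reinterpret_cast<void**>(&b), M * N * sizeof(float)) != hipSuccess ||
           hipMalloc(reinterpret_cast<void**>(&c), M * N * sizeof(float)) != hipSuccess)
            return 1;
        hipLaunchKernelGGL(kernel_struct, dim3(1), dim3(256), 0, 0, Argument{a, b, c, M, N, K});
        hipLaunchKernelGGL(kernel_flat, dim3(1), dim3(256), 0, 0,
                           static_cast<const float*>(a), static_cast<const float*>(b), c, M, N, K);
        return hipDeviceSynchronize() == hipSuccess ? 0 : 1;
    }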
include/ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp

@@ -21,70 +21,23 @@ namespace ck {
 template <typename GridwiseTsmm,
           typename FloatAB,
           typename FloatC,
-          typename BLayout,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop,
           typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_tsmm_dl_v1r3(const FloatAB* p_a_grid,
-                            const FloatAB* p_b_grid,
-                            FloatC* p_c_grid,
-                            index_t M,
-                            index_t N,
-                            index_t K,
-                            index_t K0,
-                            index_t k_batch,
-                            index_t MPadded,
-                            index_t NPadded,
-                            const Block2CTileMap block_2_ctile_map)
+        kernel_tsmm_dl_v1r3(typename GridwiseTsmm::Argument karg)
             //: in __global__ functions, struct is
             // better for reduced load overhead
 {
-    // strides depend on B's layout
-    if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-    {
-        GridwiseTsmm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop, GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(
-            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch,
-            K, N, N, MPadded, NPadded, block_2_ctile_map);
-    }
-    else
-    {
-        GridwiseTsmm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop, GridwiseTsmm,
-                                   CGlobalMemoryDataOperation>(
-            p_a_grid, p_b_grid, p_c_grid, M, N, K, K0, k_batch,
-            K, K, N, MPadded, NPadded, block_2_ctile_map);
-    }
+    GridwiseTsmm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop, GridwiseTsmm,
+                               CGlobalMemoryDataOperation>(karg);
 }
 
 template <index_t BlockSize,
@@ -137,8 +90,8 @@ struct GridwiseTsmmDl_km_kn_mn
                                           index_t StrideA_,
                                           index_t StrideB_,
                                           index_t StrideC_,
-                                          index_t MPadded_,
-                                          index_t NPadded_,
+                                          // index_t MPadded_,
+                                          // index_t NPadded_,
                                           // index_t KPadded_,
                                           index_t K0_,
                                           index_t k_batch_)
@@ -151,8 +104,8 @@ struct GridwiseTsmmDl_km_kn_mn
               StrideA{StrideA_},
               StrideB{StrideB_},
               StrideC{StrideC_},
-              MPadded(MPadded_),
-              NPadded(NPadded_),
+              // MPadded(MPadded_),
+              // NPadded(NPadded_),
               // KPadded(KPadded_),
               K0(K0_),
               k_batch(k_batch_)
@@ -167,8 +120,8 @@ struct GridwiseTsmmDl_km_kn_mn
         index_t M, N, K;
         index_t StrideA, StrideB, StrideC;
         //:
-        index_t MPadded;
-        index_t NPadded;
+        // index_t MPadded;
+        // index_t NPadded;
         // index_t KPadded;
         index_t K0;
         index_t k_batch;
@@ -367,12 +320,12 @@ struct GridwiseTsmmDl_km_kn_mn
     __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
     {
-        // const auto MPadded = CalculateMPadded(karg.M);
-        // const auto NPadded = CalculateNPadded(karg.N);
+        const auto MPadded = CalculateMPadded(karg.M);
+        const auto NPadded = CalculateNPadded(karg.N);
         const auto a_grid_desc_kbatch_k0_m_k1 = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0);
         const auto b_grid_desc_kbatch_k0_n_k1 = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
+            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
 
         const auto KBatch_a = a_grid_desc_kbatch_k0_m_k1.GetLength(I0);
@@ -480,32 +433,27 @@ struct GridwiseTsmmDl_km_kn_mn
     template <bool HasMainKBlockLoop,
               bool HasDoubleTailKBlockLoop,
               typename GridwiseTsmm,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const FloatAB* p_a_grid,
-                               const FloatAB* p_b_grid,
-                               FloatC* p_c_grid,
-                               index_t M,
-                               index_t N,
-                               index_t K,
-                               index_t K0,
-                               index_t k_batch,
-                               index_t StrideA,
-                               index_t StrideB,
-                               index_t StrideC,
-                               index_t MPadded,
-                               index_t NPadded,
-                               const Block2CTileMap& block_2_ctile_map)
+    __device__ static void Run(const Argument& karg)
     {
         constexpr index_t shared_block_size =
             GridwiseTsmm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
 
         __shared__ FloatAB p_shared_block[shared_block_size];
 
-        const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
-            M, MPadded, K, StrideA, k_batch, K0); //
-        const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
-            K, NPadded, N, StrideB, k_batch, K0); //
-        const auto c_grid_desc_m_n = GridwiseTsmm::MakeCGridDescriptor_M_N(M, N, StrideC);
+        const Block2CTileMap& block_2_ctile_map = Block2CTileMap{};
+        const auto MPadded      = CalculateMPadded(karg.M);
+        const auto NPadded      = CalculateNPadded(karg.N);
+        const FloatAB* p_a_grid = karg.p_a_grid;
+        const FloatAB* p_b_grid = karg.p_b_grid;
+        FloatC* p_c_grid        = karg.p_c_grid;
+
+        const auto a_grid_desc_kbatch_k0_m_k1 = GridwiseTsmm::MakeAGridDescriptor_KBatch_K0_M_K1(
+            karg.M, MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0); //
+        const auto b_grid_desc_kbatch_k0_n_k1 = GridwiseTsmm::MakeBGridDescriptor_KBatch_K0_N_K1(
+            karg.K, NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0); //
+        const auto c_grid_desc_m_n =
+            GridwiseTsmm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
 
         const auto a_grid_desc_kbatch_k0_m0_m1_k1 =
             GridwiseTsmm::MakeAGridDescriptor_Kbatch_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1); //
@@ -522,8 +470,8 @@ struct GridwiseTsmmDl_km_kn_mn
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
 
         const auto c_m0_n0_block_cluster_idx =
-            block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(get_block_1d_id(), N, k_batch);
+            block_2_ctile_map.convert_1D_block_idx_to_3D_tuple(
+                get_block_1d_id(), karg.N, karg.k_batch);
 
         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
@@ -559,8 +507,8 @@ struct GridwiseTsmmDl_km_kn_mn
             decltype(a_block_desc_copy_kbatch_k0_m0_m1_k1), // block tensor desc
             ABlockTransferSrcAccessOrder,                   // 5-dim
             Sequence<0, 1, 2, 3, 4>,
             ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1, // SrcVectorTensorLengths
             ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1, // DstVectorTensorLengths
             ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
             Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
             false,
@@ -661,7 +609,7 @@ struct GridwiseTsmmDl_km_kn_mn
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_grid_desc_kbatch_k0_m0_m1_k1,
                                      a_global_buf); // a_global_buf -> reg_tmp_buf
             a_blockwise_copy.RunWrite(a_block_desc_copy_kbatch_k0_m0_m1_k1,
                                       a_block_even_buf); // reg_tmp_buf->a_block_even_buf
@@ -674,7 +622,7 @@ struct GridwiseTsmmDl_km_kn_mn
         if constexpr(HasMainKBlockLoop)
         {
-            // const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
+            const auto K0 = a_grid_desc_kbatch_k0_m0_m1_k1.GetLength(I1);
 
             index_t k_block_data_begin = 0;
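One idiom worth noting in the context lines above: __builtin_amdgcn_readfirstlane, under the comment "HACK: this force index data into SGPR". Because the block-to-tile index math is uniform across a wavefront, readfirstlane lets the compiler keep those indices in scalar registers rather than per-lane VGPRs. A minimal sketch of the idiom (AMDGPU-only builtin; the kernel is a toy, not CK's):

    #include <hip/hip_runtime.h>

    __global__ void tile_coords(float* out, int tiles_per_row)
    {
        // the tile coordinates depend only on blockIdx, so they are uniform
        // across the wavefront; readfirstlane asserts that to the compiler,
        // letting im0/in0 live in SGPRs instead of VGPRs
        const int im0 =
            __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.x) / tiles_per_row);
        const int in0 =
            __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.x) % tiles_per_row);
        if(threadIdx.x == 0)
            out[blockIdx.x] = static_cast<float>(im0 * 1000 + in0);
    }

    int main()
    {
        float* out = nullptr;
        if(hipMalloc(reinterpret_cast<void**>(&out), 8 * sizeof(float)) != hipSuccess)
            return 1;
        hipLaunchKernelGGL(tile_coords, dim3(8), dim3(64), 0, 0, out, 4);
        return hipDeviceSynchronize() == hipSuccess ? 0 : 1;
    }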