Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3eaadd61
Commit
3eaadd61
authored
May 16, 2024
by
letaoqin
Browse files
first
parent
a0ae1c61
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
566 additions
and
6 deletions
+566
-6
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+9
-0
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+467
-0
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+87
-5
No files found.
include/ck/stream_config.hpp
View file @
3eaadd61
...
...
@@ -17,3 +17,12 @@ struct StreamConfig
bool
flush_cache
=
false
;
int
rotating_count
=
1
;
};
/// Plain aggregate describing one GEMM kernel configuration.
///
/// Every numeric field defaults to 1 and the name defaults to an empty string,
/// so a default-constructed GemmConfig equals GemmConfig{1, 1, 1, 1, ""}.
/// The type stays an aggregate, so positional brace initialization keeps working.
struct GemmConfig
{
    int tile_m{1};         // tile extent along M
    int tile_n{1};         // tile extent along N
    int split_k{1};        // split-K factor
    int stages{1};         // number of pipeline stages
    std::string op_name{}; // human-readable name of the operator instance
};
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
3eaadd61
...
...
@@ -33,7 +33,7 @@ struct BaseInvoker
{
return
float
{
0
};
}
/// Default occupancy query: ignores the argument and reports 1.
/// Derived invokers may override this to report a kernel-specific value.
virtual int GetOccupancy(const BaseArgument* /*p_arg*/)
{
    constexpr int default_occupancy = 1;
    return default_occupancy;
}
virtual
~
BaseInvoker
()
{}
};
...
...
@@ -67,6 +67,8 @@ struct BaseOperator
p_arg
->
p_workspace_
=
p_workspace
;
}
//virtual int GetOccupancy() { return 1; }
/// Default configuration report: every numeric field is 1 and the name is empty.
/// Derived operators may override this to describe their actual configuration.
virtual GemmConfig GetConfig()
{
    GemmConfig default_config;
    default_config.tile_m  = 1;
    default_config.tile_n  = 1;
    default_config.split_k = 1;
    default_config.stages  = 1;
    default_config.op_name = "";
    return default_config;
}
virtual
~
BaseOperator
()
{}
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
3eaadd61
...
...
@@ -129,6 +129,430 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Invoker
struct
Invoker
:
public
BaseInvoker
{
int
GetOccupancy
(
const
BaseArgument
*
p_arg
)
override
{
int
occupancy
=
0
;
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
ignore
=
arg
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
KBatch
);
index_t
k_grain
=
arg
.
KBatch
*
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
hipError_t
rtn
;
rtn
=
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
GridwiseGemm
::
GetSharedMemoryNumberOfByte
());
hip_check_error
(
rtn
);
};
constexpr
index_t
minimum_occupancy
=
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
}
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
else
{
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
if
(
arg
.
KBatch
>
1
)
{
if
constexpr
(
!
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
}
return
occupancy
>
0
?
occupancy
:
1
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
...
...
@@ -741,6 +1165,49 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return
str
.
str
();
}
// static int GetOccupancy2()
// {
// int occupancy = 1;
// constexpr index_t minimum_occupancy =
// BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy>;
// hipError_t rtn;
// rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
// &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
// hip_check_error(rtn);
// return ++occupancy;
// }
// int GetOccupancy() override
// {
// int occupancy = 3;
// // constexpr index_t minimum_occupancy =
// // BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// // const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
// // true,
// // InMemoryDataOperationEnum::Set,
// // minimum_occupancy>;
// // hipError_t rtn;
// // rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
// // &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
// // hip_check_error(rtn);
// return ++occupancy;
// }
/// Describes this operator instance as a GemmConfig: the M and N block-tile sizes,
/// the pipeline's prefetch stage count and the instance's type string.
/// split_k is always reported as 1 here.
GemmConfig GetConfig() override
{
    GemmConfig config;
    config.tile_m  = MPerBlock;
    config.tile_n  = NPerBlock;
    config.split_k = 1;
    config.stages  = GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
    config.op_name = GetTypeString();
    return config;
}
};
}
// namespace device
...
...
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
3eaadd61
...
...
@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_kbatch
=
0
;
int
best_occupancy
=
0
;
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hipError_t
rtn
;
rtn
=
hipGetDevice
(
&
dev
);
hip_check_error
(
rtn
);
rtn
=
hipGetDeviceProperties
(
&
dev_prop
,
dev
);
hip_check_error
(
rtn
);
int
num_cu
=
dev_prop
.
multiProcessorCount
;
float
config_score
=
1
;
int
config_waves
=
INT_MAX
;
int
current_tile_m
=
0
;
int
current_occupancy
=
0
;
float
current_tflops
=
0
;
GemmConfig
best_config
;
ignore
=
config_score
;
ignore
=
config_waves
;
ignore
=
current_tile_m
;
ignore
=
best_config
;
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
...
...
@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification,
kbatch_list
=
{
KBatch
};
}
auto
candidate_config
=
op_ptr
->
GetConfig
();
int
num_tile_m
=
(
M
+
candidate_config
.
tile_m
-
1
)
/
candidate_config
.
tile_m
;
int
num_tile_n
=
(
N
+
candidate_config
.
tile_n
-
1
)
/
candidate_config
.
tile_n
;
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
auto
kbatch_curr
=
kbatch_list
[
i
];
...
...
@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification,
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
int
occupancy
=
invoker_ptr
->
GetOccupancy
(
argument_ptr
.
get
());
if
(
occupancy
==
0
)
continue
;
int
ctas_per_wave
=
occupancy
*
num_cu
;
int
ctas_for_problem
=
num_tile_m
*
num_tile_n
*
kbatch_curr
;
const
int
num_waves_total
=
(
ctas_for_problem
+
ctas_per_wave
-
1
)
/
ctas_per_wave
;
const
float
num_waves_fractional
=
ctas_for_problem
/
float
(
ctas_per_wave
);
const
float
current_score
=
float
(
num_waves_total
)
-
num_waves_fractional
;
std
::
cout
<<
"tile_m: "
<<
num_tile_m
<<
" tile_n: "
<<
num_tile_n
<<
" occupancy: "
<<
occupancy
<<
" current_score:"
<<
current_score
<<
" ctas_per_wave: "
<<
ctas_per_wave
<<
" ctas_for_problem: "
<<
ctas_for_problem
<<
" num_waves_total: "
<<
num_waves_total
<<
" num_waves_fractional: "
<<
num_waves_fractional
<<
" kbatch_curr: "
<<
kbatch_curr
<<
std
::
endl
;
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
...
...
@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification,
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
", occupancy: "
<<
occupancy
<<
" "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
...
...
@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification,
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
best_occupancy
=
occupancy
;
}
if
(
num_waves_total
>
1
&&
num_waves_total
<
10
)
{
if
((
current_score
<
config_score
)
||
((
config_waves
>
num_waves_total
)
&&
(
current_score
<
config_score
+
0.1
f
)))
{
best_config
.
tile_m
=
candidate_config
.
tile_m
;
best_config
.
tile_n
=
candidate_config
.
tile_n
;
best_config
.
stages
=
candidate_config
.
stages
;
best_config
.
split_k
=
kbatch_curr
;
best_config
.
op_name
=
op_name
;
config_score
=
current_score
;
current_tile_m
=
candidate_config
.
tile_m
;
config_waves
=
num_waves_total
;
current_occupancy
=
occupancy
;
current_tflops
=
tflops
;
}
// else if(abs(current_score - config_score) < 0.001f &&
// (best_config.stages < candidate_config.stages ||
// kbatch_curr < best_config.split_k ||
// current_tile_m < candidate_config.tile_m))
// {
// best_config.tile_m = candidate_config.tile_m;
// best_config.tile_n = candidate_config.tile_n;
// best_config.stages = candidate_config.stages;
// best_config.split_k = kbatch_curr;
// best_config.op_name = op_name;
// current_tile_m = candidate_config.tile_m;
// config_waves = num_waves_total;
// current_occupancy = occupancy;
// current_tflops = tflops;
// }
}
}
else
...
...
@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification,
std
::
cout
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
" KBatch = "
<<
best_kbatch
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
" : "
<<
" occupancy: "
<<
best_occupancy
<<
" "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
std
::
cout
<<
"tile_m: "
<<
best_config
.
tile_m
<<
" tile_n: "
<<
best_config
.
tile_n
<<
" split_k: "
<<
best_config
.
split_k
<<
" stages: "
<<
best_config
.
stages
<<
", config_score: "
<<
config_score
<<
", tflops: "
<<
current_tflops
<<
", current_occupancy: "
<<
current_occupancy
<<
" name: "
<<
best_config
.
op_name
<<
", KBatch "
<<
best_config
.
split_k
<<
std
::
endl
;
return
pass
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment