Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3eaadd61
Commit
3eaadd61
authored
May 16, 2024
by
letaoqin
Browse files
first
parent
a0ae1c61
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
566 additions
and
6 deletions
+566
-6
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+9
-0
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+467
-0
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+87
-5
No files found.
include/ck/stream_config.hpp
View file @
3eaadd61
...
@@ -17,3 +17,12 @@ struct StreamConfig
...
@@ -17,3 +17,12 @@ struct StreamConfig
bool
flush_cache
=
false
;
bool
flush_cache
=
false
;
int
rotating_count
=
1
;
int
rotating_count
=
1
;
};
};
/// Plain aggregate describing one GEMM kernel configuration.
///
/// Every numeric field defaults to 1 and the name defaults to an empty
/// string, so a default-constructed object equals `GemmConfig{1, 1, 1, 1, ""}`.
struct GemmConfig
{
    /// Tile extent used to split the M dimension into tiles.
    int tile_m{1};
    /// Tile extent used to split the N dimension into tiles.
    int tile_n{1};
    /// Split-K (k-batch) factor associated with this configuration.
    int split_k{1};
    /// Stage count reported by the operator.
    /// NOTE(review): presumably pipeline/prefetch stages -- confirm with the kernel.
    int stages{1};
    /// Name of the operator instance this configuration belongs to.
    std::string op_name{};
};
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
3eaadd61
...
@@ -33,7 +33,7 @@ struct BaseInvoker
...
@@ -33,7 +33,7 @@ struct BaseInvoker
{
{
return
float
{
0
};
return
float
{
0
};
}
}
/// Default occupancy query for an invoker.
///
/// The base implementation ignores its argument and always reports 1;
/// derived invokers may override it to report a kernel-specific value.
virtual int GetOccupancy(const BaseArgument* /*unused*/)
{
    return 1;
}
virtual
~
BaseInvoker
()
{}
virtual
~
BaseInvoker
()
{}
};
};
...
@@ -67,6 +67,8 @@ struct BaseOperator
...
@@ -67,6 +67,8 @@ struct BaseOperator
p_arg
->
p_workspace_
=
p_workspace
;
p_arg
->
p_workspace_
=
p_workspace
;
}
}
//virtual int GetOccupancy() { return 1; }
/// Default configuration query for an operator.
///
/// The base implementation returns a GemmConfig with tile_m, tile_n,
/// split_k and stages all set to 1 and an empty op_name; derived
/// operators may override it to describe their actual kernel.
virtual GemmConfig GetConfig()
{
    return {1, 1, 1, 1, ""};
}
virtual
~
BaseOperator
()
{}
virtual
~
BaseOperator
()
{}
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
3eaadd61
This diff is collapsed.
Click to expand it.
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
3eaadd61
...
@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification,
float
best_tflops
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
float
best_kbatch
=
0
;
float
best_kbatch
=
0
;
int
best_occupancy
=
0
;
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hipError_t
rtn
;
rtn
=
hipGetDevice
(
&
dev
);
hip_check_error
(
rtn
);
rtn
=
hipGetDeviceProperties
(
&
dev_prop
,
dev
);
hip_check_error
(
rtn
);
int
num_cu
=
dev_prop
.
multiProcessorCount
;
float
config_score
=
1
;
int
config_waves
=
INT_MAX
;
int
current_tile_m
=
0
;
int
current_occupancy
=
0
;
float
current_tflops
=
0
;
GemmConfig
best_config
;
ignore
=
config_score
;
ignore
=
config_waves
;
ignore
=
current_tile_m
;
ignore
=
best_config
;
// profile device GEMM instances
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
...
@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification,
kbatch_list
=
{
KBatch
};
kbatch_list
=
{
KBatch
};
}
}
auto
candidate_config
=
op_ptr
->
GetConfig
();
int
num_tile_m
=
(
M
+
candidate_config
.
tile_m
-
1
)
/
candidate_config
.
tile_m
;
int
num_tile_n
=
(
N
+
candidate_config
.
tile_n
-
1
)
/
candidate_config
.
tile_n
;
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
{
auto
kbatch_curr
=
kbatch_list
[
i
];
auto
kbatch_curr
=
kbatch_list
[
i
];
...
@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification,
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
int
occupancy
=
invoker_ptr
->
GetOccupancy
(
argument_ptr
.
get
());
if
(
occupancy
==
0
)
continue
;
int
ctas_per_wave
=
occupancy
*
num_cu
;
int
ctas_for_problem
=
num_tile_m
*
num_tile_n
*
kbatch_curr
;
const
int
num_waves_total
=
(
ctas_for_problem
+
ctas_per_wave
-
1
)
/
ctas_per_wave
;
const
float
num_waves_fractional
=
ctas_for_problem
/
float
(
ctas_per_wave
);
const
float
current_score
=
float
(
num_waves_total
)
-
num_waves_fractional
;
std
::
cout
<<
"tile_m: "
<<
num_tile_m
<<
" tile_n: "
<<
num_tile_n
<<
" occupancy: "
<<
occupancy
<<
" current_score:"
<<
current_score
<<
" ctas_per_wave: "
<<
ctas_per_wave
<<
" ctas_for_problem: "
<<
ctas_for_problem
<<
" num_waves_total: "
<<
num_waves_total
<<
" num_waves_fractional: "
<<
num_waves_fractional
<<
" kbatch_curr: "
<<
kbatch_curr
<<
std
::
endl
;
// re-init C to zero before profiling next kernel
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
c_device_buf
.
SetZero
();
...
@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification,
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
", occupancy: "
<<
occupancy
<<
" "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
...
@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification,
best_ave_time
=
ave_time
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
best_kbatch
=
kbatch_curr
;
best_occupancy
=
occupancy
;
}
if
(
num_waves_total
>
1
&&
num_waves_total
<
10
)
{
if
((
current_score
<
config_score
)
||
((
config_waves
>
num_waves_total
)
&&
(
current_score
<
config_score
+
0.1
f
)))
{
best_config
.
tile_m
=
candidate_config
.
tile_m
;
best_config
.
tile_n
=
candidate_config
.
tile_n
;
best_config
.
stages
=
candidate_config
.
stages
;
best_config
.
split_k
=
kbatch_curr
;
best_config
.
op_name
=
op_name
;
config_score
=
current_score
;
current_tile_m
=
candidate_config
.
tile_m
;
config_waves
=
num_waves_total
;
current_occupancy
=
occupancy
;
current_tflops
=
tflops
;
}
// else if(abs(current_score - config_score) < 0.001f &&
// (best_config.stages < candidate_config.stages ||
// kbatch_curr < best_config.split_k ||
// current_tile_m < candidate_config.tile_m))
// {
// best_config.tile_m = candidate_config.tile_m;
// best_config.tile_n = candidate_config.tile_n;
// best_config.stages = candidate_config.stages;
// best_config.split_k = kbatch_curr;
// best_config.op_name = op_name;
// current_tile_m = candidate_config.tile_m;
// config_waves = num_waves_total;
// current_occupancy = occupancy;
// current_tflops = tflops;
// }
}
}
}
}
else
else
...
@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification,
std
::
cout
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
std
::
cout
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
" KBatch = "
<<
best_kbatch
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
" KBatch = "
<<
best_kbatch
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" : "
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
" occupancy: "
<<
best_occupancy
<<
" "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
std
::
cout
<<
"tile_m: "
<<
best_config
.
tile_m
<<
" tile_n: "
<<
best_config
.
tile_n
<<
" split_k: "
<<
best_config
.
split_k
<<
" stages: "
<<
best_config
.
stages
<<
", config_score: "
<<
config_score
<<
", tflops: "
<<
current_tflops
<<
", current_occupancy: "
<<
current_occupancy
<<
" name: "
<<
best_config
.
op_name
<<
", KBatch "
<<
best_config
.
split_k
<<
std
::
endl
;
return
pass
;
return
pass
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment