Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
fe15fcc0
Commit
fe15fcc0
authored
Jul 31, 2024
by
Harisankar Sadasivan
Browse files
debugging prints added.
parent
6c5111b7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
95 additions
and
45 deletions
+95
-45
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
+2
-1
example/01_gemm/run_gemm_example_streamk_v2.inc
example/01_gemm/run_gemm_example_streamk_v2.inc
+30
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
+33
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
+29
-36
No files found.
example/01_gemm/CMakeLists.txt
View file @
fe15fcc0
...
@@ -23,6 +23,7 @@ add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
...
@@ -23,6 +23,7 @@ add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v2
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v2
)
add_example_executable
(
example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp
)
target_compile_options
(
example_gemm_xdl_fp16_streamk_v3 PRIVATE -ggdb -O1 -march=native
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_streamk_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_streamk_v3
)
add_example_executable
(
example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v3
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_v3
)
...
...
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
View file @
fe15fcc0
...
@@ -8,7 +8,8 @@
...
@@ -8,7 +8,8 @@
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_t
;
// using CShuffleDataType = ck::half_t;
using
CShuffleDataType
=
float
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc
View file @
fe15fcc0
...
@@ -239,6 +239,25 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -239,6 +239,25 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return
true
;
return
true
;
}
}
std
::
size_t
workspace_size
=
gemm
.
GetWorkSpaceSize
(
&
argument
);
if
(
workspace_size
!=
0
)
{
workspace
.
Realloc
(
workspace_size
);
gemm
.
SetWorkSpacePointer
(
&
argument
,
workspace
.
GetDeviceBuffer
());
}
// if(workspace_size != 0)
// {
// float* ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
// size_t ws_dwords = workspace_size / sizeof(float);
// workspace.FromDevice(ws_ptr);
// printf("ws size=%0zu\n",workspace_size);
// for(size_t i = 0; i < ws_dwords; i++)
// {
// uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
// printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
// }
// free(ws_ptr);
// }
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
...
@@ -261,8 +280,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -261,8 +280,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
printf
(
"device copy initiated
\n
"
);
// HS
if
((
workspace_size
!=
0
)
&&
(
Streamk_sel
>
0
))
{
printf
(
"entered if
\n
"
);
workspace
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
}
else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
printf
(
"device copy finished
\n
"
);
// HS
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
...
@@ -273,8 +299,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -273,8 +299,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
printf
(
"before running timing
\n
"
);
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
printf
(
"after running timing
\n
"
);
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
View file @
fe15fcc0
...
@@ -131,25 +131,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -131,25 +131,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
printf
(
"inside run
\n
"
);
if
(
stream_config
.
log_level_
>
0
)
if
(
stream_config
.
log_level_
>
0
)
{
{
arg
.
Print
();
arg
.
Print
();
}
}
printf
(
"done printing arg
\n
"
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
printf
(
"done checking arg validity
\n
"
);
float
ave_time
=
0
;
float
ave_time
=
0
;
index_t
k_grain
=
KPerBlock
;
index_t
k_grain
=
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
printf
(
"done finding k_split
\n
"
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Atomic
)
StreamKReductionStrategy
::
Atomic
)
{
{
hipGetErrorString
(
hipMemsetAsync
(
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
}
...
@@ -216,12 +218,14 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -216,12 +218,14 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
hipGetErrorString
(
hipMemsetAsync
(
hipGetErrorString
(
hipMemsetAsync
(
workspace_semaphore
,
workspace_semaphore
,
0
,
0
,
arg
.
block_2_ctile_map_streamk
.
get_workspace_size_for_semaphore
(),
sizeof
(
uint32_t
),
//arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config
.
stream_id_
));
stream_config
.
stream_id_
));
};
};
printf
(
"before ave_time
\n
"
);
ave_time
=
launch_and_time_kernel_with_preprocess
(
ave_time
=
launch_and_time_kernel_with_preprocess
(
stream_config
,
preprocess
,
kernel
,
grid_dim
,
dim3
(
BlockSize
),
0
,
arg
);
stream_config
,
preprocess
,
kernel
,
grid_dim
,
dim3
(
BlockSize
),
0
,
arg
);
printf
(
"after ave_time
\n
"
);
}
}
}
}
};
};
...
@@ -242,7 +246,9 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -242,7 +246,9 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
true
,
true
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
minimum_occupancy
>
;
printf
(
"before running lambda
\n
"
);
Run
(
kernel
);
Run
(
kernel
);
printf
(
"after running lambda
\n
"
);
}
}
}
}
// Tail number could be One to Seven
// Tail number could be One to Seven
...
@@ -443,6 +449,28 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -443,6 +449,28 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
}
};
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
p_arg
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
return
p_arg
->
block_2_ctile_map_streamk
.
get_workspace_size
(
sizeof
(
GemmAccDataType
));
}
else
{
return
0
;
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
}
static
constexpr
bool
IsValidCompilationParameter
()
static
constexpr
bool
IsValidCompilationParameter
()
{
{
// TODO: properly implement this check
// TODO: properly implement this check
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
View file @
fe15fcc0
...
@@ -1191,6 +1191,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1191,6 +1191,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const
BElementwiseOperation
b_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
// Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
// Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
// problem.N,
// problem.N,
// AK0Number * problem.KPadded,
// AK0Number * problem.KPadded,
...
@@ -1218,8 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1218,8 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
dp_start_block_idx
&&
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
dp_start_block_idx
&&
static_cast
<
uint32_t
>
(
block_idx
)
<
static_cast
<
uint32_t
>
(
block_idx
)
<
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
is_reduction_block
=
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
num_k_block_main_loop
=
iter_end
-
iter_start
;
num_k_block_main_loop
=
iter_end
-
iter_start
;
...
@@ -1229,6 +1238,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1229,6 +1238,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
StreamKReductionStrategy
::
Reduction
)
{
{
is_reduction_block
=
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
if
(
is_reduction_block
)
if
(
is_reduction_block
)
{
{
// descriptors
// descriptors
...
@@ -1347,7 +1358,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1347,7 +1358,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CElementwiseOperation
{}};
CElementwiseOperation
{}};
// block synchronization
// block synchronization
wg_barrier
.
wait_eq
(
reduction_idx
,
tile_acc_offset_end
-
tile_acc_offset_start
);
wg_barrier
.
wait_eq
(
0
,
block_2_ctile_map_streamk
.
sk_num_blocks
);
// wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end -
// tile_acc_offset_start);
#if 0
#if 0
if(threadIdx.x == 0) {
if(threadIdx.x == 0) {
...
@@ -1428,7 +1441,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1428,7 +1441,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
partial_acc_store_step_m
);
partial_acc_store_step_m
);
}
}
}
}
return
;
continue
;
}
}
}
}
...
@@ -1446,25 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1446,25 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
block_work_idx
=
auto
block_work_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
...
@@ -1764,7 +1758,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1764,7 +1758,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
CShuffleDataType
,
// typename SrcData,
CShuffleDataType
,
// typename SrcData,
CDataType
,
// typename DstData,
C
Shuffle
DataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
),
decltype
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
@@ -1881,17 +1875,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1881,17 +1875,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
}
});
});
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
if
(
is_sk_block
)
{
// increase the counter for this tile
workgroup_barrier
wg_barrier
(
p_semaphore
);
wg_barrier
.
inc
(
tile_idx
);
}
}
}
}
// exit condition
// exit condition
iter_end
-=
current_iter_length
;
iter_end
-=
current_iter_length
;
...
@@ -1905,7 +1888,17 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1905,7 +1888,17 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
// make sure next loop LDS is ready for use
// make sure next loop LDS is ready for use
block_sync_lds
();
block_sync_lds
();
}
}
}
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
if
(
is_sk_block
)
{
// increase the counter for this tile
workgroup_barrier
wg_barrier
(
p_semaphore
);
wg_barrier
.
inc
(
0
);
}
}
}
// for loop
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment