gaoqiong / composable_kernel_ROCM · Commits

Commit 4d914af3, authored Oct 31, 2024 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 223a2abe 4b798833

Showing 13 changed files with 779 additions and 4 deletions (+779 −4)
python/ck4inductor/grouped_conv_fwd/op.py              +93   −0
python/ck4inductor/universal_gemm/gen_instances.py     +4    −1
python/ck4inductor/universal_gemm/op.py                +3    −0
python/ck4inductor/util.py                             +4    −1
script/convert_miopen_driver_to_profiler.py            +3    −2
test/CMakeLists.txt                                    +1    −0
test/ck_tile/CMakeLists.txt                            +1    −0
test/ck_tile/gemm/CMakeLists.txt                       +4    −0
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp           +29   −0
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc  +41   −0
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp      +318  −0
test/scatter_gather/CMakeLists.txt                     +2    −0
test/scatter_gather/scatter_gather.cpp                 +276  −0

python/ck4inductor/grouped_conv_fwd/op.py (new file, 0 → 100644)
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

from dataclasses import asdict, dataclass
from typing import Optional, Tuple


@dataclass
class CKGroupedConvFwdOp:
    n_dim_spatial: int
    a_layout: str
    b_layout: str
    ds_layout: Tuple[str]
    e_layout: str
    a_element_dtype: str
    b_element_dtype: str
    acc_dtype: str
    c_shuffle_dtype: str
    ds_element_dtype: Tuple[str]
    e_element_dtype: str
    a_elementwise_op: str
    b_elementwise_op: str
    cde_elementwise_op: str
    conv_forward_specialization: str
    gemm_specialization: str
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    ak1: int
    bk1: int
    m_per_xdl: int
    n_per_xdl: int
    m_xdl_per_wave: int
    n_xdl_per_wave: int
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]
    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int
    cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[  # noqa
        int,
        int,
        int,
        int,
    ]
    cde_block_transfer_scalar_per_vector_n_per_block: int
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: str
    a_compute_dtype: Optional[str] = None
    b_compute_dtype: Optional[str] = None

    def name(self):
        # cpp alias for template instance
        return (
            f"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_"
            f"{self.key_name()}"
        )

    def key_name(self):
        # TBD; must be unique per instance. Intended for use as a dict key
        return "_".join(
            [
                "K"
                + field_name.replace("_", "").lower()
                + "V"
                + (
                    "x".join(map(str, iter(field_value)))
                    if isinstance(field_value, tuple)
                    else str(field_value).replace(":", "")
                )
                for field_name, field_value in self.dict_items()
            ]
        )

    def dict_items(self):
        return asdict(self).items()
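
For illustration, here is a minimal, hypothetical sketch of what key_name() produces: each field is encoded as K<fieldname>V<value>, tuples are joined with x, and the pieces are concatenated with underscores. ToyOp below is not part of this commit; it only borrows CKGroupedConvFwdOp's encoding.

from dataclasses import asdict, dataclass

@dataclass
class ToyOp:
    # hypothetical fields mimicking a few CKGroupedConvFwdOp entries
    block_size: int = 256
    m_per_block: int = 128
    a_block_transfer_src_access_order: tuple = (1, 0, 2)

    def key_name(self):
        # same encoding as CKGroupedConvFwdOp.key_name
        return "_".join(
            "K" + name.replace("_", "").lower() + "V"
            + ("x".join(map(str, value)) if isinstance(value, tuple)
               else str(value).replace(":", ""))
            for name, value in asdict(self).items()
        )

print(ToyOp().key_name())
# KblocksizeV256_KmperblockV128_KablocktransfersrcaccessorderV1x0x2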

python/ck4inductor/universal_gemm/gen_instances.py
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 import logging
 import os
 import subprocess
-from dataclasses import replace
+from dataclasses import fields, replace
 from functools import lru_cache, partial
 from typing import List
...

python/ck4inductor/universal_gemm/op.py
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 from dataclasses import asdict, dataclass
 from typing import Optional, Tuple
...

python/ck4inductor/util.py
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 import functools
 import os

 @functools.lru_cache(None)
 def library_path():
-    return os.path.join(os.path.dirname(__file__), 'library')
+    return os.path.join(os.path.dirname(__file__), "library")

script/convert_miopen_driver_to_profiler.py
...
@@ -65,8 +65,9 @@ def parse_data_type(args):
     if args.ck_profier_op == "grouped_conv_fwd":
         args.data_type = 3
     if args.data_type == "bfp16":
-        if args.ck_profier_op == "grouped_conv_bwd_weight" or \
-           args.ck_profier_op == "grouped_conv_bwd_data" or \
+        if args.ck_profier_op == "grouped_conv_bwd_weight":
+            args.data_type = 5
+        if args.ck_profier_op == "grouped_conv_bwd_data" or \
            args.ck_profier_op == "grouped_conv_fwd":
             args.data_type = 2
...
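
The fix splits the bfp16 mapping: grouped_conv_bwd_weight now selects data_type 5, while grouped_conv_bwd_data and grouped_conv_fwd keep data_type 2; previously all three shared one condition and mapped to 2. A small runnable sketch of the patched branch (the SimpleNamespace stand-in for the argparse result is illustrative; "ck_profier_op" is spelled as in the source):

from types import SimpleNamespace

def map_bfp16_data_type(args):
    # mirrors only the patched bfp16 branch of parse_data_type, not the full script
    if args.data_type == "bfp16":
        if args.ck_profier_op == "grouped_conv_bwd_weight":
            args.data_type = 5
        if args.ck_profier_op == "grouped_conv_bwd_data" or \
           args.ck_profier_op == "grouped_conv_fwd":
            args.data_type = 2
    return args

print(map_bfp16_data_type(SimpleNamespace(
    ck_profier_op="grouped_conv_bwd_weight", data_type="bfp16")).data_type)  # 5
print(map_bfp16_data_type(SimpleNamespace(
    ck_profier_op="grouped_conv_fwd", data_type="bfp16")).data_type)         # 2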

test/CMakeLists.txt
...
@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
     add_subdirectory(smfmac_op)
 endif()
 add_subdirectory(position_embedding)
+add_subdirectory(scatter_gather)

test/ck_tile/CMakeLists.txt
 add_subdirectory(image_to_column)
+add_subdirectory(gemm)

test/ck_tile/gemm/CMakeLists.txt (new file, 0 → 100644)
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
    add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp)
endif()

test/ck_tile/gemm/test_gemm_mem_pipeline.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <tuple>

#include "gtest/gtest.h"

#include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp"

using F16 = ck_tile::half_t;
using F32 = float;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// clang-format off
using KernelTypes = ::testing::Types<
    //         ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
    std::tuple<    Row,     Col,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Col,     Row,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Row,     Row,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Col,     Col,     Row,       F16,       F16,         F32,       F16>
    >;
// clang-format on

TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);

#include "test_gemm_mem_pipeline_ut_cases.inc"

test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc (new file, 0 → 100644)
#pragma once

TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
{
    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
    constexpr int N = 1024;
    constexpr int K = 320;

    for(int M : Ms)
        this->Run(M, N, K);
}

TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
{
    std::vector<int> Ms{127, 255, 312, 799, 1573};
    constexpr int N = 1024;
    constexpr int K = 320;

    for(int M : Ms)
        this->Run(M, N, K);
}

TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
{
    std::vector<int> Ms{127};
    constexpr int N = 1024;
    constexpr int K = 432;

    for(int M : Ms)
        this->Run(M, N, K);
}

TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
{
    std::vector<int> Ms{512};
    constexpr int N = 1024;
    constexpr int K = 512;

    for(int M : Ms)
        this->Run(M, N, K);
}
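
A note on these K values: the fixture in test_gemm_mem_pipeline_util.hpp fixes K_Tile = 32, so K = 320 and K = 512 split into whole K-tiles while PaddK's K = 432 leaves a remainder of 16, presumably to exercise the padded-K path (the fixture compiles with kPadA/kPadB/kPadC all true). A quick check:

K_TILE = 32  # K_Tile constant from test_gemm_mem_pipeline_util.hpp
for K in (320, 432, 512):
    print(K, divmod(K, K_TILE))  # 320 -> (10, 0); 432 -> (13, 16); 512 -> (16, 0)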

test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <sstream>
#include <gtest/gtest.h>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"

template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test
{
    protected:
    using ALayout     = std::tuple_element_t<0, Tuple>;
    using BLayout     = std::tuple_element_t<1, Tuple>;
    using CLayout     = std::tuple_element_t<2, Tuple>;
    using ADataType   = std::tuple_element_t<3, Tuple>;
    using BDataType   = std::tuple_element_t<4, Tuple>;
    using AccDataType = std::tuple_element_t<5, Tuple>;
    using CDataType   = std::tuple_element_t<6, Tuple>;

    // TODO: expose tile size through test t-param ?
    struct gemm_basic_args
    {
        const void* p_a;
        const void* p_b;
        void* p_c;
        ck_tile::index_t kbatch;
        ck_tile::index_t M;
        ck_tile::index_t N;
        ck_tile::index_t K;
        ck_tile::index_t stride_A;
        ck_tile::index_t stride_B;
        ck_tile::index_t stride_C;
    };

    void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s)
    {
        // TODO: This should be parameterized in tests
        constexpr ck_tile::index_t M_Tile = 128;
        constexpr ck_tile::index_t N_Tile = 128;
        constexpr ck_tile::index_t K_Tile = 32;

        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        constexpr ck_tile::index_t M_Warp_Tile = 32;
        constexpr ck_tile::index_t N_Warp_Tile = 32;
        constexpr ck_tile::index_t K_Warp_Tile = 8;

        constexpr bool kPadA = true;
        constexpr bool kPadB = true;
        constexpr bool kPadC = true;

        constexpr int kBlockPerCu = 1;

        // ===============================================

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
        using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;

        using GemmEpilogue = ck_tile::Default2DEpilogue<
            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;

        using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;

        using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(args.K);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v  = tail_number_.value;

            using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
                ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                      BDataType,
                                                      AccDataType,
                                                      GemmShape,
                                                      Traits,
                                                      ck_tile::GemmPipelineScheduler::Intrawave,
                                                      has_hot_loop_v,
                                                      tail_number_v>>;
            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKargs(args.p_a,
                                             args.p_b,
                                             args.p_c,
                                             args.M,
                                             args.N,
                                             args.K,
                                             args.stride_A,
                                             args.stride_B,
                                             args.stride_C);

            const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
            constexpr dim3 blocks = Kernel::BlockSize();

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args:"
                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            ck_tile::launch_kernel(
                s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        };

        if(has_hot_loop)
        {
            // Tail pipeline One to Seven
            if(tail_num == ck_tile::TailNumber::One)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
            }
            else if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }

            if constexpr(BaseGemmPipeline::PrefetchStages > 2)
            {
                if(tail_num == ck_tile::TailNumber::Two)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Two>{});
                }
            }
            if constexpr(BaseGemmPipeline::PrefetchStages > 3)
            {
                if(tail_num == ck_tile::TailNumber::Three)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Three>{});
                }
            }
            if constexpr(BaseGemmPipeline::PrefetchStages > 4)
            {
                if(tail_num == ck_tile::TailNumber::Four)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Four>{});
                }
            }
            if constexpr(BaseGemmPipeline::PrefetchStages > 5)
            {
                if(tail_num == ck_tile::TailNumber::Five)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Five>{});
                }
            }
            if constexpr(BaseGemmPipeline::PrefetchStages > 6)
            {
                if(tail_num == ck_tile::TailNumber::Six)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Six>{});
                }
            }
            if constexpr(BaseGemmPipeline::PrefetchStages > 7)
            {
                if(tail_num == ck_tile::TailNumber::Seven)
                {
                    Run(ck_tile::bool_constant<true>{},
                        ck_tile::integral_constant<ck_tile::TailNumber,
                                                   ck_tile::TailNumber::Seven>{});
                }
            }
        }
        else
        {
            // Tail number always Full - #PrefetchStages
            if(tail_num == ck_tile::TailNumber::Full)
            {
                Run(ck_tile::bool_constant<false>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else
            {
                std::ostringstream err;
                err << "When there's no hot loop, this tail number \"" << tail_num
                    << "\" is not supported! " << __FILE__ << ":" << __LINE__
                    << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }
        }
    }

    public:
    std::vector<int> k_batches_;

    void SetUp() override { k_batches_ = {1}; }

    void Run(const int M,
             const int N,
             const int K,
             const int StrideA = 0,
             const int StrideB = 0,
             const int StrideC = 0)
    {
        for(auto kb : k_batches_)
        {
            RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
        }
    }

    void RunSingle(const int M,
                   const int N,
                   const int K,
                   const int StrideA,
                   const int StrideB,
                   const int StrideC,
                   int kbatch = 1)
    {
        using namespace ck_tile::literals;

        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if constexpr(std::is_same_v<decltype(layout),
                                            ck_tile::tensor_layout::gemm::RowMajor>)
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
                }
                else
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
                }
            };

        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    // give a chance if stride is zero, return a default packed stride
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
        std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
        std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});

        ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
        ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

        ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
        ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        a_m_k_dev_buf.ToDevice(a_m_k.data());
        b_k_n_dev_buf.ToDevice(b_k_n.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        gemm_basic_args args;
        args.p_a      = a_m_k_dev_buf.GetDeviceBuffer();
        args.p_b      = b_k_n_dev_buf.GetDeviceBuffer();
        args.p_c      = c_m_n_dev_buf.GetDeviceBuffer();
        args.kbatch   = kbatch;
        args.M        = M;
        args.N        = N;
        args.K        = K;
        args.stride_A = stride_A;
        args.stride_B = stride_B;
        args.stride_C = stride_C;

        invoke_gemm(args, ck_tile::stream_config{nullptr, false});

        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
        bool pass = true;

        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);

        EXPECT_TRUE(pass);
    }
};
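
invoke_gemm bridges runtime and compile time: num_loop, has_hot_loop, and tail_num are only known at runtime, but GemmPipelineAgBgCrMem takes them as template parameters, so the ladder of if / if constexpr branches pre-instantiates each supported (hot-loop, tail-number) combination and launches the matching one. A minimal Python sketch of that selection idea (the variant table is illustrative, not the real instantiation set):

# Sketch: runtime (has_hot_loop, tail_num) values pick one pre-built kernel
# variant, loosely mirroring the if / if-constexpr ladder in invoke_gemm.
VARIANTS = {
    (True, "One"):   "GemmPipelineAgBgCrMem<..., hot_loop=true,  TailNumber::One>",
    (True, "Full"):  "GemmPipelineAgBgCrMem<..., hot_loop=true,  TailNumber::Full>",
    (False, "Full"): "GemmPipelineAgBgCrMem<..., hot_loop=false, TailNumber::Full>",
}

def select_variant(has_hot_loop: bool, tail_num: str) -> str:
    if (has_hot_loop, tail_num) not in VARIANTS:
        # analogous to the runtime_error thrown in the C++ else-branch
        raise RuntimeError(
            f"When there's no hot loop, this tail number \"{tail_num}\" is not supported!")
    return VARIANTS[(has_hot_loop, tail_num)]

print(select_variant(True, "Full"))  # hot-loop Full variant is launched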

test/scatter_gather/CMakeLists.txt (new file, 0 → 100644)
add_test_executable(test_scatter_gather scatter_gather.cpp)
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)

test/scatter_gather/scatter_gather.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <time.h>
#include <unordered_set>

#include "ck_tile/core.hpp"

#ifndef TEST_SCATTER_GATHER_VERBOSE
#define TEST_SCATTER_GATHER_VERBOSE 1
#endif

#define HIP_CALL(call)                                                                \
    do                                                                                \
    {                                                                                 \
        hipError_t err = call;                                                        \
        if(err != hipSuccess)                                                         \
        {                                                                             \
            printf("[hiperror](%d) failed to call %s", static_cast<int>(err), #call); \
            exit(0);                                                                  \
        }                                                                             \
    } while(0)

/*
TODO:
    This is a simple design of scatter/gather through indexing transform, with limitations.
    We may design a scatter/gather adaptor layer directly inside tile window.
*/
template <ck_tile::index_t ROW_TILE_SIZE = 8,
          ck_tile::index_t COL_TILE_SIZE = 32 * 8,
          ck_tile::index_t BLOCK_SIZE    = 256,
          ck_tile::index_t ALIGNMENT     = 8,
          typename INDEX_BUF_TYPE        = ck_tile::index_t,
          typename DATA_TYPE             = ck_tile::fp16_t>
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
                                   const INDEX_BUF_TYPE* dst_row_idx_ptr,
                                   const DATA_TYPE* src_ptr,
                                   DATA_TYPE* dst_ptr,
                                   ck_tile::index_t n_row_total,
                                   ck_tile::index_t /*n_row_select*/,
                                   ck_tile::index_t n_cols)
{
    using namespace ck_tile;

    // some constexpr vars
    constexpr index_t vec = ALIGNMENT;
    static_assert(COL_TILE_SIZE % vec == 0);
    constexpr index_t col_lanes = COL_TILE_SIZE / vec;
    constexpr index_t warp_size = ck_tile::get_warp_size();
    static_assert(warp_size % col_lanes == 0);
    constexpr index_t row_lanes = warp_size / col_lanes;
    constexpr index_t num_warps = BLOCK_SIZE / warp_size;
    static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
    constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
    static_assert(row_repeat == 1,
                  "currently indexing is not supported (and would not be performant) if "
                  "row_repeat is larger");

    // tile partitioner
    index_t tile_col_idx = 0;
    index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;

    // create our tile distribution, which tells us the location of different threads
    constexpr auto src_dist = make_static_tile_distribution(
        tile_distribution_encoding<
            sequence<1>,
            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
            tuple<sequence<1>, sequence<1, 2>>,
            tuple<sequence<1>, sequence<2, 0>>,
            sequence<1, 2>,
            sequence<0, 1>>{});

    const auto coord     = src_dist.calculate_index();
    const auto row_coord = coord[number<0>{}] + tile_row_idx;

    // load the current row index from the indexing buffer. we do not use ck_tile utility here
    INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
    INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
    // printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
    //        static_cast<int>(src_row_id), static_cast<int>(dst_row_id));

    const auto src_view =
        make_naive_tensor_view<address_space_enum::global>(src_ptr,
                                                           make_tuple(n_row_total, n_cols),
                                                           make_tuple(n_cols, 1),
                                                           number<vec>{}, // alignment
                                                           number<1>{});

    const auto src_gather_view = transform_tensor_view(
        src_view,
        make_tuple(
            // here we replace row_idx, which is loaded from another buffer
            make_indexing_transform(n_row_total, src_row_id),
            make_pass_through_transform(n_cols)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    auto src_tile = make_tile_window(src_gather_view,
                                     make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
                                     {tile_row_idx, tile_col_idx},
                                     src_dist);

    const auto dst_view =
        make_naive_tensor_view<address_space_enum::global>(dst_ptr,
                                                           make_tuple(n_row_total, n_cols),
                                                           make_tuple(n_cols, 1),
                                                           number<vec>{},
                                                           number<1>{});

    const auto dst_scatter_view = transform_tensor_view(
        dst_view,
        make_tuple(
            // here we replace row_idx, which is loaded from another buffer
            make_indexing_transform(n_row_total, dst_row_id),
            make_pass_through_transform(n_cols)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    auto dst_tile = make_tile_window(dst_scatter_view,
                                     make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
                                     {tile_row_idx, tile_col_idx},
                                     src_dist /*reuse distribution*/);

    // we finished descriptor construction and index calculation, now start load/store
    for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
    {
        // note that scatter/gather use the same load/store API as a normal memory operation
        auto data = load_tile(src_tile);
        store_tile(dst_tile, data);
        move_tile_window(src_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
        move_tile_window(dst_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
    }
}

union pixel
{
    struct __attribute__((packed))
    {
        unsigned int r : 6;
        unsigned int c : 10;
    };
    ushort data;
};

struct unique_linear_rand
{
    unique_linear_rand(int capacity_) : capacity(capacity_) {}
    std::unordered_set<int> set;
    int gen()
    {
        if(static_cast<int>(set.size()) >= capacity)
        {
            printf("overflow, but will give you a number as well\n");
            return std::rand() % capacity;
        }
        while(1)
        {
            int r = std::rand() % capacity;
            if(set.count(r) == 1)
            {
                continue;
            }
            set.insert(r);
            return r;
        }
    }
    int capacity;
};

int main()
{
    int row_total  = 64;
    int row_select = 8 * 2;
    int col        = 256 * 2;
    using fp16_t   = ck_tile::fp16_t;

    constexpr int row_tile = 8;
    constexpr int col_tile = 256;

    fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
    for(int i_r = 0; i_r < row_total; i_r++)
    {
        for(int i_c = 0; i_c < col; i_c++)
        {
            int i = i_r * col + i_c;
            pixel p;
            p.r      = i_r;
            p.c      = i_c;
            ushort d = p.data;
            src[i]   = ck_tile::bit_cast<fp16_t>(d); // for simplicity, just cast
        }
    }
    fp16_t* dst  = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
    int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
    int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));

    // std::srand(std::time(std::nullptr));
    // std::srand(11935);
    std::srand(std::time(nullptr));
    auto src_gen = unique_linear_rand(row_total);
    auto dst_gen = unique_linear_rand(row_total);

    // dst index must be unique. src is fine
    for(int i_r = 0; i_r < row_select; i_r++)
    {
        src_idx[i_r] = src_gen.gen();
        dst_idx[i_r] = dst_gen.gen();
    }

    void* dev_src;
    void* dev_dst;
    void* dev_src_idx;
    void* dev_dst_idx;
    HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t)));
    HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t)));
    HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int)));
    HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int)));

    HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice));
    HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
    HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice));

    constexpr int bdim = 256;
    int gdim           = (row_select + row_tile - 1) / row_tile;

    row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx),
                                                           reinterpret_cast<int*>(dev_dst_idx),
                                                           reinterpret_cast<fp16_t*>(dev_src),
                                                           reinterpret_cast<fp16_t*>(dev_dst),
                                                           row_total,
                                                           row_select,
                                                           col);

    HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost));

#if TEST_SCATTER_GATHER_VERBOSE
    printf("select row:");
    for(int i_r = 0; i_r < row_select; i_r++)
    {
        printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]);
    }
    printf("\n");
#endif

    int err_cnt = 0;
    for(int i_r = 0; i_r < row_select; i_r++)
    {
        for(int i_c = 0; i_c < col; i_c++)
        {
            int i      = dst_idx[i_r] * col + i_c;
            pixel p    = ck_tile::bit_cast<pixel>(dst[i]);
            bool is_ok = p.r == src_idx[i_r] && p.c == i_c;
            if(!is_ok)
            {
                if(i_c == 0)
                    printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]);
                err_cnt++;
            }
        }
    }
#if TEST_SCATTER_GATHER_VERBOSE
    printf("err:%d\n", err_cnt);
#endif

    free(src);
    free(dst);
    free(src_idx);
    free(dst_idx);

    return err_cnt == 0 ? 0 : -1;
}
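
The verification loop in main() checks exactly this row-level semantics: each selected row r is gathered from src at src_idx[r] and scattered to dst at dst_idx[r]. A minimal host-side reference sketch (function and variable names are illustrative, not part of the commit):

def row_scatter_gather_ref(src, src_idx, dst_idx, n_cols):
    # dst[dst_idx[r]] = src[src_idx[r]] for each selected row r;
    # only the scattered rows are defined, matching the kernel's behavior
    dst = [[0] * n_cols for _ in src]
    for s, d in zip(src_idx, dst_idx):
        dst[d] = src[s][:]
    return dst

src = [[10 * r + c for c in range(4)] for r in range(6)]
out = row_scatter_gather_ref(src, src_idx=[5, 0], dst_idx=[2, 3], n_cols=4)
print(out[2] == src[5], out[3] == src[0])  # True True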