Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5899b0fc
Commit
5899b0fc
authored
Oct 20, 2022
by
Alan Turner
Browse files
Formatting
parent
e72ecc75
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
41 additions
and
38 deletions
+41
-38
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+6
-5
src/targets/gpu/jit/ck_batched_gemm.cpp
src/targets/gpu/jit/ck_batched_gemm.cpp
+4
-6
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
...argets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
+6
-5
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm.hpp
.../gpu/kernels/include/migraphx/kernels/ck_batched_gemm.hpp
+20
-17
test/verify/0ck_batched_gemm.cpp
test/verify/0ck_batched_gemm.cpp
+3
-3
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
5899b0fc
...
...
@@ -131,16 +131,17 @@ struct find_ck_batched_gemm
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_batched_gemm
{
ins
->
get_operator
()},
ins
->
inputs
());
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_batched_gemm
{
ins
->
get_operator
()},
ins
->
inputs
());
}
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
match
::
find_matches
(
mpm
,
find_ck_batched_gemm
{});
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
match
::
find_matches
(
mpm
,
find_ck_batched_gemm
{});
}
}
// namespace gpu
...
...
src/targets/gpu/jit/ck_batched_gemm.cpp
View file @
5899b0fc
...
...
@@ -131,10 +131,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
return
it
->
second
;
}
static
std
::
size_t
get_batch_stride
(
const
shape
&
s
)
{
return
s
.
strides
()[
s
.
strides
().
size
()
-
3
];
}
static
std
::
size_t
get_batch_stride
(
const
shape
&
s
)
{
return
s
.
strides
()[
s
.
strides
().
size
()
-
3
];
}
struct
ck_batched_gemm_compiler
:
compiler
<
ck_batched_gemm_compiler
>
{
...
...
@@ -186,7 +183,7 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
hip_compile_options
options
;
// batch_count
auto
out_lens
=
c_shape
.
lens
();
auto
out_lens
=
c_shape
.
lens
();
auto
batch_count
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
batchStrideA
=
get_batch_stride
(
a_shape
);
...
...
@@ -209,7 +206,8 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
options
.
kernel_name
=
"ck_batched_gemm_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
src
=
interpolate_string
(
ck_batched_gemm_kernel
,
{{
"instance"
,
join_strings
(
instance
,
","
)}});
auto
src
=
interpolate_string
(
ck_batched_gemm_kernel
,
{{
"instance"
,
join_strings
(
instance
,
","
)}});
return
compile_hip_code_object
(
src
,
options
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
View file @
5899b0fc
...
...
@@ -111,12 +111,13 @@ constexpr F for_each(Iterator first, Iterator last, F f)
}
template
<
class
Iterator
,
class
T
>
constexpr
void
fill
(
Iterator
first
,
Iterator
last
,
const
T
&
val
)
constexpr
void
fill
(
Iterator
first
,
Iterator
last
,
const
T
&
val
)
{
while
(
first
!=
last
)
{
*
first
=
val
;
++
first
;
}
while
(
first
!=
last
)
{
*
first
=
val
;
++
first
;
}
}
template
<
class
Iterator
,
class
Predicate
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
View file @
5899b0fc
...
...
@@ -62,10 +62,10 @@ constexpr auto to_ck_tensor()
template
<
class
Tensor
>
constexpr
auto
to_ck_batched_tensor
()
{
constexpr
auto
s
=
get_shape_c
<
Tensor
>
{};
constexpr
auto
s
=
get_shape_c
<
Tensor
>
{};
constexpr
auto
sz
=
s
.
lens
.
size
();
return
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
s
.
lens
[
sz
-
2
],
s
.
lens
[
sz
-
1
]),
ck
::
make_tuple
(
s
.
strides
[
sz
-
2
],
s
.
strides
[
sz
-
1
]));
ck
::
make_tuple
(
s
.
strides
[
sz
-
2
],
s
.
strides
[
sz
-
1
]));
}
template
<
class
F
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm.hpp
View file @
5899b0fc
...
...
@@ -53,13 +53,13 @@ template <ck::index_t NumDTensor>
struct
ComputePtrOffsetOfStridedBatch
{
__device__
ComputePtrOffsetOfStridedBatch
(
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
ck
::
index_t
BatchStrideE
)
ck
::
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
ck
::
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
...
...
@@ -94,15 +94,17 @@ struct ComputePtrOffsetOfStridedBatch
ck
::
index_t
BatchStrideE_
;
};
template
<
class
G
,
class
Settings
,
class
A
,
class
B
,
class
E
,
class
...
Ds
>
__device__
void
ck_batched_gemm
(
Settings
s
,
A
a
,
B
b
,
E
e
,
Ds
...
ds
)
{
constexpr
const
G
gemm
{};
constexpr
const
auto
a_grid_desc_m_k
=
gemm
.
matrix_padder
.
PadADescriptor_M_K
(
to_ck_batched_tensor
<
A
>
());
constexpr
const
auto
b_grid_desc_n_k
=
gemm
.
matrix_padder
.
PadBDescriptor_N_K
(
to_ck_batched_tensor
<
B
>
());
constexpr
const
auto
e_grid_desc_m_n
=
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_batched_tensor
<
E
>
());
constexpr
const
auto
a_grid_desc_m_k
=
gemm
.
matrix_padder
.
PadADescriptor_M_K
(
to_ck_batched_tensor
<
A
>
());
constexpr
const
auto
b_grid_desc_n_k
=
gemm
.
matrix_padder
.
PadBDescriptor_N_K
(
to_ck_batched_tensor
<
B
>
());
constexpr
const
auto
e_grid_desc_m_n
=
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_batched_tensor
<
E
>
());
constexpr
const
auto
ds_grid_desc_m_n
=
ck
::
make_tuple
(
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_batched_tensor
<
Ds
>
())...);
constexpr
const
auto
block_2_etile_map
=
gemm
.
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
...
...
@@ -124,17 +126,18 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
constexpr
const
bool
HasMainKBlockLoop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
0
>
{})
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
2
>
{}));
static
constexpr
ck
::
index_t
NumDTensor
=
gemm
.
NumDTensor
;
std
::
array
<
ck
::
index_t
,
NumDTensor
>
batchStrideDs
;
ck
::
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
batchStrideDs
[
i
]
=
s
.
batchStrideC
;
});
const
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch
{
s
.
batchStrideA
,
s
.
batchStrideB
,
batchStrideDs
,
s
.
batchStrideC
};
ck
::
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
batchStrideDs
[
i
]
=
s
.
batchStrideC
;
});
const
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch
{
s
.
batchStrideA
,
s
.
batchStrideB
,
batchStrideDs
,
s
.
batchStrideC
};
auto
batch_count
=
s
.
batch_count
;
const
ck
::
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
ck
::
get_grid_size
()
/
batch_count
);
const
ck
::
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
ck
::
get_block_1d_id
()
/
num_blocks_per_batch
);
const
ck
::
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
ck
::
get_block_1d_id
()
/
num_blocks_per_batch
);
const
ck
::
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
ck
::
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
...
...
@@ -154,7 +157,7 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
a
.
data
()
+
a_batch_offset
,
b
.
data
()
+
b_batch_offset
,
p_ds_grid_grp
,
p_ds_grid_grp
,
e
.
data
()
+
e_batch_offset
,
p_shared
,
gemm
.
a_element_op
,
...
...
test/verify/0ck_batched_gemm.cpp
View file @
5899b0fc
...
...
@@ -32,14 +32,14 @@ struct ck_batched_gemm : verify_program<ck_batched_gemm>
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
std
::
size_t
b
=
2
;
std
::
size_t
m
=
3
;
std
::
size_t
n
=
3
;
std
::
size_t
k
=
3
;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
b
,
m
,
k
}};
std
::
vector
<
float
>
v1
(
b
*
m
*
k
,
1
);
std
::
vector
<
float
>
v2
(
b
*
k
*
n
,
1
);
//{1, 2, 3, 4, 5, 6, 7, 8};
std
::
vector
<
float
>
v1
(
b
*
m
*
k
,
1
);
std
::
vector
<
float
>
v2
(
b
*
k
*
n
,
1
);
//{1, 2, 3, 4, 5, 6, 7, 8};
// auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m1_shape);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
m1_shape
,
v1
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment