gaoqiong / MIGraphX · Commits

Commit e72ecc75
authored Oct 20, 2022 by Alan Turner
parent 54dd72b6

    Add batched gemm

Showing 7 changed files with 767 additions and 1 deletion (+767 −1)
src/targets/gpu/fuse_ck.cpp                                                     +66   -1
src/targets/gpu/jit/ck_batched_gemm.cpp                                         +229  -0
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp                  +9    -0
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp                         +9    -0
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm.hpp            +171  -0
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm_includes.hpp   +232  -0
test/verify/0ck_batched_gemm.cpp                                                +51   -0
src/targets/gpu/fuse_ck.cpp

@@ -46,6 +46,41 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);

struct ck_batched_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_batched_gemm"; }

    void check_gemm_shape(const shape& s) const
    {
        if(contains(s.lens(), 1))
            MIGRAPHX_THROW("Invalid shape for ck_batched_gemm");
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.not_broadcasted();
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto n = inputs.size();
        auto a = inputs[n - 2];
        auto b = inputs[n - 1];
        check_gemm_shape(a);
        check_gemm_shape(b);
        return op.compute_shape({a, b});
    }
};
MIGRAPHX_REGISTER_OP(ck_batched_gemm);

namespace {

MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...

@@ -62,6 +97,20 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
    return true;
}

MIGRAPHX_PRED_MATCHER(is_ck_batched_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    auto a = ins->inputs().front()->get_shape();
    auto b = ins->inputs().back()->get_shape();
    if(a.lens().size() < 3 or b.lens().size() < 3)
        return false;
    if(a.lens().back() > 1024)
        return false;
    return true;
}

struct find_ck_gemm
{
    // Find a gemm that can be replaced with a ck_gemm
...

@@ -74,9 +123,25 @@ struct find_ck_gemm
    }
};

struct find_ck_batched_gemm
{
    // Find a batched gemm that can be replaced with a ck_batched_gemm
    auto matcher() const { return match::name("dot")(is_ck_batched_gemm().bind("gemm")); }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins = r.result;
        mpm.get_module().replace_instruction(ins, ck_batched_gemm{ins->get_operator()}, ins->inputs());
    }
};

} // namespace

-void fuse_ck::apply(module_pass_manager& mpm) const
-{
-    match::find_matches(mpm, find_ck_gemm{});
-}
+void fuse_ck::apply(module_pass_manager& mpm) const
+{
+    match::find_matches(mpm, find_ck_gemm{});
+    match::find_matches(mpm, find_ck_batched_gemm{});
+}

} // namespace gpu
...
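For illustration only (this snippet is not part of the diff): the new matcher accepts a dot whose inputs are at least 3-D and whose inner dimension is at most 1024, which the pass then rewrites into a single gpu::ck_batched_gemm instruction. A minimal sketch using the public migraphx API, mirroring the verify test added at the end of this commit:

// Sketch: builds a batched dot that is_ck_batched_gemm accepts; after the
// fuse_ck pass runs, the dot would be replaced by gpu::ck_batched_gemm.
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>

migraphx::program make_batched_dot()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::half_type, {2, 3, 3}}; // {batch, M, K}, K <= 1024
    auto a = mm->add_parameter("a", s);
    auto b = mm->add_parameter("b", s);
    mm->add_instruction(migraphx::make_op("dot"), a, b);
    return p;
}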
src/targets/gpu/jit/ck_batched_gemm.cpp  0 → 100644 (new file)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);

// NOLINTNEXTLINE
static const char* const ck_batched_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_batched_gemm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
namespace migraphx {
extern "C" {
__global__ void ck_batched_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
auto settings = make_ck_batched_gemm_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_COUNT}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEA}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEB}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEC}));
ck_batched_gemm<CK_DeviceBatchedGemmMultipleD<${instance}>>(settings, a, b, c);
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

static std::size_t block_size_index = 15;
static std::size_t padding_index    = 13;

static std::size_t get_block_size(const std::vector<std::string>& s)
{
    return std::stoull(s[block_size_index]);
}

static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
{
    auto mpb = std::stoull(s[block_size_index + 1]);
    auto npb = std::stoull(s[block_size_index + 2]);
    return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}

static void set_padding(std::vector<std::string>& s, const std::string p) { s[padding_index] = p; }
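As a rough worked example of the launch-size arithmetic above (illustration only; the tile and block sizes are made up rather than read from a real CK instance string): the per-batch grid is the number of MPerBlock x NPerBlock output tiles, compile_op later multiplies it by the batch count, and set_launch_params receives the resulting total thread count.

// Standalone sketch of the grid/block-size arithmetic (hypothetical sizes).
#include <cstddef>
#include <iostream>

static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main()
{
    std::size_t m = 384, n = 768, batch_count = 16;     // problem sizes (assumed)
    std::size_t mpb = 128, npb = 128, block_size = 256; // tile and block sizes (assumed)
    auto tiles_per_batch = int_div_ceil(m, mpb) * int_div_ceil(n, npb); // 3 * 6 = 18
    auto grid_size       = batch_count * tiles_per_batch;              // 288 workgroups
    std::cout << grid_size * block_size << "\n"; // 73728 global threads, as passed to set_launch_params
}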
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [=](auto&&... xs) {
        action();
        f(std::forward<decltype(xs)>(xs)...);
    };
}

using tuning_entry = std::pair<std::vector<shape>, size_t>;

static std::vector<tuning_entry> read_tuning(const std::string& s)
{
    if(not fs::exists(s))
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}

static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
    if(tuning.empty())
        std::cout << "*********** Warning: No CK tuning!" << std::endl;
    auto it = std::find_if(
        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
    if(it == tuning.end())
    {
        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
        return 6;
    }
    return it->second;
}

static std::size_t get_batch_stride(const shape& s)
{
    return s.strides()[s.strides().size() - 3];
}
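A minimal sketch of what get_batch_stride returns for a packed 3-D input (not part of this commit; it assumes the standard migraphx::shape stride layout):

#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    migraphx::shape a{migraphx::shape::half_type, {2, 3, 4}}; // lens {B, M, K}
    // For a packed row-major shape the strides are {M*K, K, 1} = {12, 4, 1},
    // so strides()[size - 3] == 12: one whole M x K slice per batch step.
    std::cout << a.strides()[a.strides().size() - 3] << "\n"; // prints 12
}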
struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
{
    static std::string get_layout(const shape& s)
    {
        return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
                              : "ck::tensor_layout::gemm::RowMajor";
    }

    static std::string get_type(const shape& s)
    {
        if(s.type() == shape::half_type)
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }

    std::vector<std::string> names() const { return {"ck_batched_gemm", "gpu::ck_batched_gemm"}; }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
        auto a_shape = inputs[0];
        auto b_shape = inputs[1];
        auto c_shape = inputs[2];
        auto m       = c_shape.lens().front();
        auto n       = c_shape.lens().back();
        auto k       = a_shape.lens().back();

        auto i         = v.get("tuning_val", get_tuning_for(inputs));
        auto& instance = get_instance(i, [&](const auto& x) -> bool {
            return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
                   get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
        });

        const bool pad_m = m % 8;
        const bool pad_n = n % 8;
        const bool pad_k = k % 8;
        if(pad_m or pad_n or pad_k)
        {
            std::string padding_t = "ck::tensor_operation::device::GemmSpecialization::";
            padding_t += pad_m ? "M" : "";
            padding_t += pad_n ? "N" : "";
            padding_t += pad_k ? "K" : "";
            padding_t += "Padding";
            set_padding(instance, padding_t);
        }

        hip_compile_options options;
        // batch_count
        auto out_lens    = c_shape.lens();
        auto batch_count = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        auto batchStrideA = get_batch_stride(a_shape);
        auto batchStrideB = get_batch_stride(b_shape);
        auto batchStrideC = get_batch_stride(c_shape);
        options.params += " -DBATCH_COUNT=" + std::to_string(batch_count);
        options.params += " -DBATCHSTRIDEA=" + std::to_string(batchStrideA);
        options.params += " -DBATCHSTRIDEB=" + std::to_string(batchStrideB);
        options.params += " -DBATCHSTRIDEC=" + std::to_string(batchStrideC);
        std::cout << "Batch_count: " << batch_count << std::endl;
        std::cout << "BatchStrideA: " << batchStrideA << std::endl;
        std::cout << "BatchStrideB: " << batchStrideB << std::endl;
        std::cout << "BatchStrideC: " << batchStrideC << std::endl;

        auto block_size = get_block_size(instance);
        auto grid_size  = batch_count * get_grid_size(instance, m, n);
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs         = inputs;
        options.output         = c_shape;
        options.kernel_name    = "ck_batched_gemm_kernel";
        options.virtual_inputs = inputs;
        auto src =
            interpolate_string(ck_batched_gemm_kernel, {{"instance", join_strings(instance, ",")}});
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto shapes = to_shapes(ins->inputs());
        return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [=] {
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
                std::cout << "ck_batched_gemm: " << to_json_string(to_value(shapes)) << std::endl;
        });
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
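Two reader's notes on the compiler above (illustration only, not text from the commit). First, the tuning `instance` is a comma-joined list of CK template arguments in which the code assumes the A/B/E layouts sit at positions 0, 1 and 3, the A/B/E data types at 4, 5 and 9, the GemmSpecialization string at position 13 (padding_index), the BlockSize at 15 (block_size_index), and MPerBlock/NPerBlock at 16 and 17. Second, for the verify test added below (three packed {2, 3, 3} half tensors), compile_op would derive BATCH_COUNT=2 and BATCHSTRIDEA=BATCHSTRIDEB=BATCHSTRIDEC=9, since the batch count is the product of every output dimension except the last two and each batch stride is strides()[size - 3] of a packed {2, 3, 3} shape.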
src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp

@@ -110,6 +110,15 @@ constexpr F for_each(Iterator first, Iterator last, F f)
    return f;
}

template <class Iterator, class T>
constexpr void fill(Iterator first, Iterator last, const T& val)
{
    while(first != last)
    {
        *first = val;
        ++first;
    }
}

template <class Iterator, class Predicate>
constexpr Iterator find_if(Iterator first, Iterator last, Predicate p)
{
...
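For illustration only, a host-side sketch of how the fill() added above behaves; the function is repeated here so the snippet is self-contained (in the kernels it lives in migraphx/kernels/algorithm.hpp, which is device-only):

#include <array>
#include <cassert>

template <class Iterator, class T>
constexpr void fill(Iterator first, Iterator last, const T& val)
{
    while(first != last)
    {
        *first = val;
        ++first;
    }
}

int main()
{
    std::array<int, 4> buf{};
    fill(buf.begin(), buf.end(), 7); // every element becomes 7
    assert(buf.front() == 7 and buf.back() == 7);
}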
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp

@@ -59,6 +59,15 @@ constexpr auto to_ck_tensor()
    });
}

template <class Tensor>
constexpr auto to_ck_batched_tensor()
{
    constexpr auto s  = get_shape_c<Tensor>{};
    constexpr auto sz = s.lens.size();
    return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[sz - 2], s.lens[sz - 1]),
                                            ck::make_tuple(s.strides[sz - 2], s.strides[sz - 1]));
}

template <class F>
struct ck_function_adaptor : F
{
...
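Worked example (illustration only): to_ck_batched_tensor builds the CK descriptor from the last two dimensions alone, so a packed tensor with lens {2, 3, 4} and strides {12, 4, 1} yields a 2-D descriptor with lengths (3, 4) and strides (4, 1); the leading batch dimension is deliberately left out of the descriptor and is instead applied as a per-batch pointer offset inside ck_batched_gemm.hpp.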
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm.hpp  0 → 100644 (new file)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_BATCHED_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_BATCHED_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_batched_gemm_includes.hpp>
#include <migraphx/kernels/shape.hpp>
namespace migraphx {

template <class T0, class T1, class T2, class T3>
struct ck_batched_gemm_settings
{
    T0 batch_count{};
    T1 batchStrideA{};
    T2 batchStrideB{};
    T3 batchStrideC{};
};

template <class... Ts>
constexpr ck_batched_gemm_settings<Ts...> make_ck_batched_gemm_settings(Ts... xs)
{
    return {xs...};
}

template <ck::index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
    __device__ ComputePtrOffsetOfStridedBatch(ck::index_t BatchStrideA,
                                              ck::index_t BatchStrideB,
                                              std::array<ck::index_t, NumDTensor> BatchStrideDs,
                                              ck::index_t BatchStrideE)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideDs_(BatchStrideDs),
          BatchStrideE_(BatchStrideE)
    {
    }

    __host__ __device__ constexpr ck::long_index_t GetAPtrOffset(ck::index_t g_idx) const
    {
        return g_idx * static_cast<ck::long_index_t>(BatchStrideA_);
    }

    __host__ __device__ constexpr ck::long_index_t GetBPtrOffset(ck::index_t g_idx) const
    {
        return g_idx * static_cast<ck::long_index_t>(BatchStrideB_);
    }

    __host__ __device__ constexpr auto GetDsPtrOffset(ck::index_t g_idx) const
    {
        std::array<ck::long_index_t, NumDTensor> ds_offset;
        ck::static_for<0, NumDTensor, 1>{}(
            [&](auto i) { ds_offset[i] = g_idx * static_cast<ck::long_index_t>(BatchStrideDs_[i]); });
        return ds_offset;
    }

    __host__ __device__ constexpr ck::long_index_t GetEPtrOffset(ck::index_t g_idx) const
    {
        return g_idx * static_cast<ck::long_index_t>(BatchStrideE_);
    }

    private:
    ck::index_t BatchStrideA_;
    ck::index_t BatchStrideB_;
    std::array<ck::index_t, NumDTensor> BatchStrideDs_;
    ck::index_t BatchStrideE_;
};

template <class G, class Settings, class A, class B, class E, class... Ds>
__device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
{
    constexpr const G gemm{};
    constexpr const auto a_grid_desc_m_k =
        gemm.matrix_padder.PadADescriptor_M_K(to_ck_batched_tensor<A>());
    constexpr const auto b_grid_desc_n_k =
        gemm.matrix_padder.PadBDescriptor_N_K(to_ck_batched_tensor<B>());
    constexpr const auto e_grid_desc_m_n =
        gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<E>());
    constexpr const auto ds_grid_desc_m_n =
        ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<Ds>())...);
    constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);

    using GridwiseGemm = typename G::GridwiseGemm;

    // tensor descriptors for block/thread-wise copy
    constexpr auto a_grid_desc_ak0_m_ak1 =
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
    constexpr auto b_grid_desc_bk0_n_bk1 =
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
    constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
    constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
    constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(
        a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
        a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));

    static constexpr ck::index_t NumDTensor = gemm.NumDTensor;
    std::array<ck::index_t, NumDTensor> batchStrideDs;
    ck::static_for<0, NumDTensor, 1>{}([&](auto i) { batchStrideDs[i] = s.batchStrideC; });
    const ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch{
        s.batchStrideA, s.batchStrideB, batchStrideDs, s.batchStrideC};

    auto batch_count = s.batch_count;
    const ck::index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(ck::get_grid_size() / batch_count);
    const ck::index_t g_idx =
        __builtin_amdgcn_readfirstlane(ck::get_block_1d_id() / num_blocks_per_batch);

    const ck::long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const ck::long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const ck::long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto p_ds_grid_grp = ck::make_tuple(ds.data()...);
    ck::static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid_grp[i] + ds_batch_offset[i]; });

    GridwiseGemm::template Run<HasMainKBlockLoop>(a.data() + a_batch_offset,
                                                  b.data() + b_batch_offset,
                                                  p_ds_grid_grp,
                                                  e.data() + e_batch_offset,
                                                  p_shared,
                                                  gemm.a_element_op,
                                                  gemm.b_element_op,
                                                  gemm.cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
}

} // namespace migraphx
#endif
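Worked example of the batch partitioning above (illustration only): if the kernel is launched with 36 workgroups for a batch_count of 2, then num_blocks_per_batch = 36 / 2 = 18, workgroups 0-17 get g_idx = 0 and workgroups 18-35 get g_idx = 1, and each pointer handed to GridwiseGemm::Run is advanced by g_idx times the corresponding batch stride. The block-to-tile map in ck_batched_gemm_includes.hpp then reduces the block id modulo the per-batch tile count, which is why the batch index is "swallowed" there.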
src/targets/gpu/kernels/include/migraphx/kernels/ck_batched_gemm_includes.hpp  0 → 100644 (new file)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_BG_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_BG_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
// #include <ck/utility/common_header.hpp>
// #include <ck/tensor_description/tensor_descriptor.hpp>
// #include <ck/tensor_description/tensor_descriptor_helper.hpp>
// #include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
// #include <ck/tensor_operation/gpu/device/device_gemm.hpp>
// #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
// #include <ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp>
// #include <ck/tensor_operation/gpu/device/matrix_padder.hpp>
// #include <ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
namespace migraphx {

template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};
    static constexpr auto I2 = ck::Number<2>{};
    static constexpr auto I3 = ck::Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, ck::index_t M01 = 8)
        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    __host__ __device__ constexpr ck::index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        const ck::index_t grid_size = M0 * N0;

        return grid_size;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        ck::index_t idx_N0 = block_1d_id % N0;
        ck::index_t idx_M0 = block_1d_id / N0;

        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        ck::index_t idx_M00          = idx_M0 / M01_;
        ck::index_t idx_M01          = idx_M0 % M01_;
        ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    ck::index_t M01_;
    CGridDesc_M_N c_grid_desc_m_n_;
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::tensor_operation::device::GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          ck::index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_BK1,
          ck::index_t BBlockLdsExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
          ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceBatchedGemmMultipleD
{
    ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
        matrix_padder{MPerBlock, NPerBlock, KPerBlock};

    // GridwiseGemm
    using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        ck::InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    // return block_id to E matrix tile idx (m0, n0) mapping
    template <class EGridDesc_M_N>
    __device__ static constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(e_grid_desc_m_n_);
    }

    static constexpr ck::index_t NumDTensor = DsDataType::Size();

    AElementwiseOperation a_element_op{};
    BElementwiseOperation b_element_op{};
    CDEElementwiseOperation cde_element_op{};
};

} // namespace migraphx
#endif
test/verify/0ck_batched_gemm.cpp  0 → 100644 (new file)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_batched_gemm : verify_program<ck_batched_gemm>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm      = p.get_main_module();
        std::size_t b = 2;
        std::size_t m = 3;
        std::size_t n = 3;
        std::size_t k = 3;
        migraphx::shape m1_shape{migraphx::shape::half_type, {b, m, k}};
        std::vector<float> v1(b * m * k, 1);
        std::vector<float> v2(b * k * n, 1);
        //{1, 2, 3, 4, 5, 6, 7, 8};
        // auto l1 = mm->add_parameter("1", m1_shape);
        // auto l2 = mm->add_parameter("2", m1_shape);
        auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
        auto l2 = mm->add_literal(migraphx::literal{m1_shape, v1});
        mm->add_instruction(migraphx::make_op("dot"), l1, l2);
        return p;
    }
};
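Reader's note (not part of the commit): this verify test builds a 2x3x3 half-precision batched dot with all-ones literal inputs; its operands are 3-D with an inner dimension of 3, so is_ck_batched_gemm accepts it and the program exercises the new gpu::ck_batched_gemm fusion and ck_batched_gemm_compiler end to end. The vector v2 and the commented-out add_parameter calls are left unused in the committed test.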