Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9c91c08d
Unverified
Commit
9c91c08d
authored
Jul 07, 2023
by
Chris Austen
Committed by
GitHub
Jul 07, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
a56bb11d
c1b8c975
Changes
124
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
871 additions
and
16 deletions
+871
-16
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+4
-0
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
+48
-0
src/targets/gpu/include/migraphx/gpu/lowering.hpp
src/targets/gpu/include/migraphx/gpu/lowering.hpp
+2
-2
src/targets/gpu/include/migraphx/gpu/time_op.hpp
src/targets/gpu/include/migraphx/gpu/time_op.hpp
+0
-2
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+456
-0
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+1
-1
src/targets/gpu/jit/gather.cpp
src/targets/gpu/jit/gather.cpp
+1
-1
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+1
-1
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+1
-1
src/targets/gpu/jit/pad.cpp
src/targets/gpu/jit/pad.cpp
+1
-1
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+1
-1
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+1
-1
src/targets/gpu/jit/roialign.cpp
src/targets/gpu/jit/roialign.cpp
+1
-1
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+1
-1
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+12
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
+164
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+72
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+11
-2
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
...ets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
+92
-0
No files found.
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
9c91c08d
...
...
@@ -27,10 +27,14 @@
#include <migraphx/config.hpp>
#include <string>
struct
hipDeviceProp_t
;
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
std
::
string
get_arch_name
(
const
hipDeviceProp_t
&
props
);
std
::
string
get_device_name
();
int
get_device_id
();
...
...
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
namespace
gpu
{
struct
fuse_ck
{
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"gpu::fuse_ck"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
src/targets/gpu/include/migraphx/gpu/lowering.hpp
View file @
9c91c08d
...
...
@@ -30,7 +30,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
namespace
gpu
{
...
...
@@ -45,7 +45,7 @@ struct lowering
context
*
ctx
;
bool
offload_copy
;
std
::
string
name
()
const
{
return
"gpu::lowering"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
mp
m
)
const
;
};
}
// namespace gpu
...
...
src/targets/gpu/
driver/
include/migraphx/gpu/
driver/perf
.hpp
→
src/targets/gpu/include/migraphx/gpu/
time_op
.hpp
View file @
9c91c08d
...
...
@@ -31,12 +31,10 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
driver
{
std
::
pair
<
double
,
double
>
time_op
(
context
&
ictx
,
operation
op
,
const
std
::
vector
<
shape
>&
inputs
,
int
n
=
100
);
}
// namespace driver
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/jit/ck_gemm.cpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <migraphx/filesystem.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING_VALUE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <${include}>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
// NOLINTNEXTLINE
static
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
template
<
class
P
>
static
std
::
string
ck_disable_warnings
(
P
p
)
{
return
interpolate_string
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
first
,
p
.
second
}}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
result
;
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&&
p
)
{
return
std
::
make_pair
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
static
std
::
vector
<
src_file
>
create_ck_headers
()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&&
p
)
{
return
src_file
{
fs
::
path
{
p
.
first
},
{
p
.
second
.
data
(),
p
.
second
.
data
()
+
p
.
second
.
size
()}};
});
return
srcs
;
}
static
const
std
::
vector
<
src_file
>&
ck_headers
()
{
static
const
auto
&
headers
=
create_ck_headers
();
return
headers
;
}
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
using
tuning_entry
=
std
::
pair
<
std
::
vector
<
shape
>
,
size_t
>
;
static
std
::
vector
<
tuning_entry
>
read_tuning
(
const
std
::
string
&
s
)
{
if
(
not
fs
::
exists
(
s
))
return
{};
return
from_value
<
std
::
vector
<
tuning_entry
>>
(
from_json_string
(
read_string
(
s
)));
}
static
float
matrix_distance
(
const
shape
&
x
,
const
shape
&
y
)
{
if
(
x
.
type
()
!=
y
.
type
())
return
std
::
numeric_limits
<
float
>::
max
();
if
(
transposed_matrix
(
x
)
!=
transposed_matrix
(
y
))
return
std
::
numeric_limits
<
float
>::
max
();
auto
sum_squared
=
std
::
inner_product
(
x
.
lens
().
rbegin
(),
x
.
lens
().
rbegin
()
+
2
,
y
.
lens
().
rbegin
(),
0
,
std
::
plus
<>
{},
[](
auto
a
,
auto
b
)
{
return
(
a
-
b
)
*
(
a
-
b
);
});
return
std
::
sqrt
(
sum_squared
);
}
static
std
::
size_t
get_tuning_for
(
const
std
::
vector
<
shape
>&
inputs
)
{
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
{
std
::
cout
<<
"*********** Warning: No CK tuning! for config:"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
}
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
{
std
::
cout
<<
"*********** Warning: CK tuning missing for config!"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
std
::
vector
<
std
::
pair
<
float
,
std
::
size_t
>>
w
;
std
::
transform
(
tuning
.
begin
(),
tuning
.
end
(),
std
::
back_inserter
(
w
),
[
&
](
const
auto
&
p
)
{
if
(
inputs
.
size
()
<
3
or
p
.
first
.
size
()
<
3
)
MIGRAPHX_THROW
(
"Invalid CK config"
);
auto
avg_distance
=
std
::
inner_product
(
p
.
first
.
begin
(),
p
.
first
.
begin
()
+
3
,
inputs
.
begin
(),
0.0
f
,
std
::
plus
<>
{},
[](
const
auto
&
x
,
const
auto
&
y
)
{
return
matrix_distance
(
x
,
y
)
/
3.0
f
;
});
return
std
::
make_pair
(
avg_distance
,
p
.
second
);
});
std
::
sort
(
w
.
begin
(),
w
.
end
());
std
::
size_t
default_value
=
4
;
if
(
not
w
.
empty
())
default_value
=
w
.
front
().
second
;
auto
tuning_val
=
value_of
(
MIGRAPHX_CK_TUNING_VALUE
{},
default_value
);
std
::
cout
<<
"*********** Warning: CK try tuning: "
<<
tuning_val
<<
std
::
endl
;
return
tuning_val
;
}
return
it
->
second
;
}
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
static
std
::
string
get_layout
(
const
shape
&
s
)
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
static
ck
::
host
::
DataType
get_type
(
const
shape
&
s
)
{
if
(
s
.
type
()
==
shape
::
half_type
)
return
ck
::
host
::
DataType
::
Half
;
else
if
(
s
.
type
()
==
shape
::
float_type
)
return
ck
::
host
::
DataType
::
Float
;
else
if
(
s
.
type
()
==
shape
::
int8_type
)
return
ck
::
host
::
DataType
::
Int8
;
else
if
(
s
.
type
()
==
shape
::
int32_type
)
return
ck
::
host
::
DataType
::
Int32
;
MIGRAPHX_THROW
(
"Unsupported ck type"
);
}
template
<
class
Iterator
,
class
F
>
static
std
::
string
ck_tuple
(
Iterator
start
,
Iterator
last
,
F
f
)
{
std
::
vector
<
std
::
string
>
s
;
std
::
transform
(
start
,
last
,
std
::
back_inserter
(
s
),
f
);
return
"ck::Tuple<"
+
join_strings
(
s
,
","
)
+
">"
;
}
static
std
::
vector
<
shape
>
adjust_inputs
(
std
::
vector
<
shape
>
inputs
,
bool
&
swap_inputs
)
{
swap_inputs
=
false
;
auto
c_shape
=
inputs
.
back
();
if
(
not
transposed_matrix
(
c_shape
))
return
inputs
;
std
::
vector
<
int64_t
>
perm
(
c_shape
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
swap
(
perm
[
perm
.
size
()
-
1
],
perm
[
perm
.
size
()
-
2
]);
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
shape
s
)
{
return
reorder_shape
(
s
,
perm
);
});
swap_inputs
=
true
;
return
inputs
;
}
static
std
::
size_t
get_batch_count
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
lens
().
rbegin
()
+
2
,
s
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
static
void
fold_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
batch_count
=
get_batch_count
(
s
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
if
(
transposed_matrix
(
s
))
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
else
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
}
static
void
remove_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
static
bool
standard_batch
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
3
)
return
true
;
std
::
vector
<
std
::
size_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
);
std
::
vector
<
std
::
size_t
>
strides
(
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
);
auto
base
=
*
(
s
.
lens
().
end
()
-
2
)
*
*
(
s
.
lens
().
end
()
-
1
);
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
stride
)
{
return
stride
/
base
;
});
return
shape
{
s
.
type
(),
lens
,
strides
}.
standard
();
}
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
const
auto
&
b_shape
=
inputs
[
1
];
if
(
std
::
any_of
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
[](
auto
input
)
{
return
not
standard_batch
(
input
);
}))
return
false
;
const
auto
&
b_strides
=
b_shape
.
strides
();
return
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
auto
rank
=
a_shape
.
lens
().
size
();
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
auto
n
=
c_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
const
bool
trans_a
=
transposed_matrix
(
a_shape
);
const
bool
trans_b
=
transposed_matrix
(
b_shape
);
const
bool
trans_e
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
e_type
=
get_type
(
c_shape
);
std
::
vector
<
bool
>
ds_layout
;
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_layout
),
[](
const
auto
&
i
)
{
return
transposed_matrix
(
i
);
});
std
::
vector
<
ck
::
host
::
DataType
>
ds_type
;
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_type
),
[](
const
auto
&
i
)
{
return
get_type
(
i
);
});
std
::
string
ck_passthrough
=
"ck_passthrough"
;
std
::
string
cde_op
=
ck_passthrough
;
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
{
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
return
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
m
,
n
,
k
,
trans_a
,
trans_b
,
trans_e
,
ds_layout
,
a_type
,
b_type
,
e_type
,
ds_type
,
ck_passthrough
,
ck_passthrough
,
cde_op
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
4
);
if
(
not
v
.
contains
(
"tuning_value"
))
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
problem
=
create_problem
(
inputs
,
v
);
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
solutions
=
problem
.
GetSolutions
(
ctx
.
get_current_device
().
get_gfx_name
());
const
auto
&
solution
=
solutions
.
at
(
tuning_value
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
blocks_per_batch
=
solution
.
grid_size
;
const
auto
block_size
=
solution
.
block_size
;
hip_compile_options
options
;
options
.
additional_src_files
=
ck_headers
();
auto
grid_size
=
can_fold_batch
(
inputs
)
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
inputs
=
inputs
;
options
.
output
=
c_shape
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
virtual_inputs
=
inputs
;
if
(
can_fold_batch
(
inputs
))
{
auto
vinputs
=
inputs
;
fold_batch_dims
(
vinputs
[
0
]);
remove_batch_dims
(
vinputs
[
1
]);
std
::
for_each
(
vinputs
.
begin
()
+
2
,
vinputs
.
end
(),
fold_batch_dims
);
options
.
virtual_inputs
=
vinputs
;
}
if
(
v
.
get
(
"check"
,
false
)
or
enabled
(
MIGRAPHX_CK_DEBUG
{}))
options
.
params
+=
" -DMIGRAPHX_CK_CHECK=1"
;
auto
src
=
interpolate_string
(
ck_gemm_kernel
,
{{
"solution"
,
template_str
},
{
"include"
,
include_header
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"blocks_per_batch"
,
to_string
(
blocks_per_batch
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"kernel"
,
options
.
kernel_name
}});
return
compile_hip_code_object
(
src
,
options
);
}
value
create_settings
(
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
v
[
"kernel"
]
=
"ck_gemm_kernel"
;
if
(
not
ins
->
module_inputs
().
empty
())
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_function"
)
+
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
return
v
;
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
const
value
&
solution
)
const
{
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
v
=
create_settings
(
ins
,
op
);
if
(
not
solution
.
is_null
())
v
[
"tuning_value"
]
=
solution
;
return
{
compile_op
(
ctx
,
shapes
,
v
),
[
=
](
module
&
m
,
instruction_ref
ins2
,
const
operation
&
code_object
)
{
if
(
enabled
(
MIGRAPHX_LOG_CK_GEMM
{}))
{
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
.
back
().
with_type
(
shapes
[
0
].
type
())};
std
::
cout
<<
"gpu::ck_gemm: "
<<
to_json_string
(
to_value
(
gemm_shapes
))
<<
std
::
endl
;
}
m
.
replace_instruction
(
ins2
,
code_object
,
ins2
->
inputs
());
}};
}
optional
<
tuning_config
>
get_tuning_config
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
,
bool
exhaustive
)
const
{
if
(
not
exhaustive
and
not
enabled
(
MIGRAPHX_TUNE_CK
{}))
return
nullopt
;
tuning_config
tc
;
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
problem
=
create_problem
(
shapes
,
create_settings
(
ins
,
op
));
auto
solutions
=
problem
.
GetSolutions
(
ctx
.
get_current_device
().
get_gfx_name
());
tc
.
solutions
.
resize
(
solutions
.
size
());
std
::
iota
(
tc
.
solutions
.
begin
(),
tc
.
solutions
.
end
(),
0
);
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
.
back
()};
tc
.
problem
=
to_value
(
gemm_shapes
);
return
tc
;
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/concat.cpp
View file @
9c91c08d
...
...
@@ -47,7 +47,7 @@ ${preamble}
extern "C" {
__global__
void ${kernel}(${params})
MIGRAPHX_GLOBAL
void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(${concat_args})(${post}, y, xs...);
...
...
src/targets/gpu/jit/gather.cpp
View file @
9c91c08d
...
...
@@ -44,7 +44,7 @@ namespace migraphx {
extern "C" {
__global__
void gather_kernel(void* in_data, void* in_indices, void* output)
MIGRAPHX_GLOBAL
void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
...
...
src/targets/gpu/jit/gathernd.cpp
View file @
9c91c08d
...
...
@@ -44,7 +44,7 @@ namespace migraphx {
extern "C" {
__global__
void gathernd_kernel(void* in_data, void* in_indices, void* output)
MIGRAPHX_GLOBAL
void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
...
...
src/targets/gpu/jit/layernorm.cpp
View file @
9c91c08d
...
...
@@ -48,7 +48,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__
void ${kernel}(${params})
MIGRAPHX_GLOBAL
void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, ${eps}, xs...);
...
...
src/targets/gpu/jit/pad.cpp
View file @
9c91c08d
...
...
@@ -44,7 +44,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
namespace migraphx {
extern "C" {
__global__
void pad_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL
void pad_kernel(void* input_p, void* output_p)
{
auto offsets = index_ints<${offsets}>{};
auto idx = make_index();
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
9c91c08d
...
...
@@ -44,7 +44,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__
void ${kernel}(${params})
MIGRAPHX_GLOBAL
void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args});
...
...
src/targets/gpu/jit/reduce.cpp
View file @
9c91c08d
...
...
@@ -45,7 +45,7 @@ namespace migraphx {
${preamble}
extern "C" {
__global__
void reduce_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL
void reduce_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
...
...
src/targets/gpu/jit/roialign.cpp
View file @
9c91c08d
...
...
@@ -41,7 +41,7 @@ namespace migraphx {
extern "C" {
__global__
void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
MIGRAPHX_GLOBAL
void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
...
...
src/targets/gpu/jit/scatternd.cpp
View file @
9c91c08d
...
...
@@ -42,7 +42,7 @@ namespace migraphx {
extern "C" {
__global__
void scatternd_kernel(void* in_indices, void* in_updates, void* output)
MIGRAPHX_GLOBAL
void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
...
...
src/targets/gpu/jit/softmax.cpp
View file @
9c91c08d
...
...
@@ -45,7 +45,7 @@ static const char* const softmax_kernel = R"__migraphx__(
namespace migraphx {
extern "C" {
__global__
void softmax_kernel(void* input_p, void* output_p)
MIGRAPHX_GLOBAL
void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
9c91c08d
...
...
@@ -272,6 +272,18 @@ struct integral_const_array : array<T, sizeof...(Xs)>
MIGRAPHX_DEVICE_CONSTEXPR
integral_const_array
()
:
base_array
({
Xs
...})
{}
};
template
<
class
T
,
class
...
Ts
>
constexpr
auto
make_const_array
(
T
x
,
Ts
...
xs
)
{
return
integral_const_array
<
typename
T
::
value_type
,
x
,
xs
...
>
{};
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
unpack
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
return
f
(
_c
<
Xs
>
...);
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_HPP
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ck/utility/common_header.hpp>
#include <ck/tensor_description/tensor_descriptor.hpp>
#include <ck/tensor_description/tensor_descriptor_helper.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
namespace
migraphx
{
namespace
detail
{
template
<
class
T
>
struct
to_ck_type_impl
{
using
type
=
T
;
};
template
<
>
struct
to_ck_type_impl
<
migraphx
::
half
>
{
using
type
=
ck
::
half_t
;
};
template
<
class
T
>
struct
to_ck_type_impl
<
const
T
>
{
using
type
=
const
typename
to_ck_type_impl
<
T
>::
type
;
};
template
<
class
Shape
>
constexpr
bool
is_row_major
()
{
constexpr
auto
strides
=
Shape
{}.
strides
;
MIGRAPHX_ASSERT
(
strides
.
size
()
>=
2
);
if
(
strides
.
back
()
==
1
)
{
MIGRAPHX_ASSERT
(
not
Shape
{}.
is_transposed
());
return
true
;
}
MIGRAPHX_ASSERT
(
strides
[
strides
.
size
()
-
2
]
==
1
);
return
false
;
}
}
// namespace detail
template
<
class
T
>
using
to_ck_type
=
typename
detail
::
to_ck_type_impl
<
T
>::
type
;
template
<
class
T
>
constexpr
auto
to_ck_pointer
(
T
*
x
)
{
return
static_cast
<
to_ck_type
<
T
>*>
(
x
);
}
template
<
class
T
>
constexpr
auto
to_ck_const_pointer
(
const
T
*
x
)
{
return
static_cast
<
const
to_ck_type
<
T
>*>
(
x
);
}
template
<
class
Shape
>
using
to_ck_gemm_layout
=
conditional_t
<
detail
::
is_row_major
<
get_shape_c
<
Shape
>>
(),
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
template
<
class
Tensor
>
constexpr
auto
to_ck_tensor
()
{
constexpr
auto
s
=
get_shape_c
<
Tensor
>
{};
return
sequence
(
s
.
lens
.
size
(),
[
&
](
auto
...
is
)
{
return
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
s
.
lens
[
is
]...),
ck
::
make_tuple
(
s
.
strides
[
is
]...));
});
}
template
<
class
F
>
struct
ck_function_adaptor
:
F
{
template
<
class
...
Ts
>
constexpr
ck_function_adaptor
(
Ts
&&
...
xs
)
:
F
(
static_cast
<
Ts
&&>
(
xs
)...)
{
}
template
<
class
T
,
class
...
Ts
>
constexpr
void
operator
()(
T
&
out
,
Ts
&&
...
xs
)
const
{
out
=
static_cast
<
const
F
&>
(
*
this
)(
static_cast
<
Ts
&&>
(
xs
)...);
}
};
struct
ck_nop
{
template
<
class
T
>
constexpr
void
operator
()(
T
&
)
const
{
}
};
struct
ck_passthrough
{
template
<
class
T
,
class
U
>
constexpr
void
operator
()(
T
&
y
,
U
x
)
const
{
y
=
x
;
}
};
struct
ck_scale
{
constexpr
ck_scale
(
float
s
)
:
scale
(
s
)
{}
template
<
class
T
,
class
U
>
constexpr
void
operator
()(
T
&
y
,
U
x
)
const
{
y
=
x
*
static_cast
<
U
>
(
scale
);
}
float
scale
;
};
struct
ck_add
{
template
<
class
T
,
class
U
>
constexpr
void
operator
()(
T
&
y
,
U
x
)
const
{
y
+=
x
;
}
};
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
#define MIGRAPHX_CK_STATIC_ASSERT(...)
#endif
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace
migraphx
{
// In CK, the B matrix is ordered as N,K instead of K,N
template
<
class
Dims
>
constexpr
auto
ck_transposeb_dims
(
Dims
dims
)
{
return
unpack
(
dims
,
[](
auto
k
,
auto
n
)
{
return
make_const_array
(
n
,
k
);
});
}
template
<
class
Tensor
>
using
ck_transposeb
=
decltype
(
make_shape
(
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
lens
),
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
strides
)));
template
<
class
G
,
class
E
,
class
A
,
class
B
,
class
...
Ds
>
__device__
void
ck_gemm_matrix
(
E
e
,
A
a
,
B
b
,
Ds
...
ds
)
{
constexpr
auto
desc
=
G
::
make_descriptor
(
to_ck_tensor
<
A
>
(),
to_ck_tensor
<
ck_transposeb
<
B
>>
(),
ck
::
make_tuple
(
to_ck_tensor
<
Ds
>
()...),
to_ck_tensor
<
E
>
());
static_assert
(
desc
.
IsValid
(),
"Invalid ck gemm."
);
G
::
Run
(
desc
,
to_ck_const_pointer
(
a
.
data
()),
to_ck_const_pointer
(
b
.
data
()),
ck
::
make_tuple
(
to_ck_const_pointer
(
ds
.
data
())...),
to_ck_pointer
(
e
.
data
()));
}
template
<
class
G
,
index_int
BlocksPerBatch
,
class
...
Ts
>
__device__
void
ck_gemm
(
Ts
...
xs
)
{
gemm_batch_args
(
make_index
(),
_c
<
BlocksPerBatch
>
,
xs
...)(
[](
auto
...
ys
)
{
ck_gemm_matrix
<
G
>
(
ys
...);
});
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
9c91c08d
...
...
@@ -32,8 +32,17 @@
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \
{ \
template <class... PrivateLiftTs> \
constexpr auto operator()(PrivateLiftTs&&... private_lifts_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...)) \
}
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
0 → 100644
View file @
9c91c08d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>
namespace
migraphx
{
template
<
class
Tensor
>
constexpr
auto
gemm_get_batches
()
{
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
new_lens
=
sequence
(
lens
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
lens
[
is
]
>
...);
});
constexpr
auto
new_strides
=
sequence
(
strides
.
size
()
-
_c
<
2
>
,
[
&
](
auto
...
is
)
{
return
make_const_array
(
_c
<
strides
[
is
]
>
...);
});
return
make_shape
(
new_lens
,
new_strides
);
}
template
<
class
Tensor
>
constexpr
auto
gemm_get_matrix
()
{
constexpr
auto
lens
=
get_shape_c
<
Tensor
>
{}.
lens
;
constexpr
auto
strides
=
get_shape_c
<
Tensor
>
{}.
strides
;
constexpr
auto
m
=
lens
.
size
()
-
_c
<
2
>
;
constexpr
auto
n
=
lens
.
size
()
-
_c
<
1
>
;
constexpr
auto
new_lens
=
make_const_array
(
_c
<
lens
[
m
]
>
,
_c
<
lens
[
n
]
>
);
constexpr
auto
new_strides
=
make_const_array
(
_c
<
strides
[
m
]
>
,
_c
<
strides
[
n
]
>
);
return
make_shape
(
new_lens
,
new_strides
);
}
template
<
class
Tensor
,
class
T
>
constexpr
auto
gemm_batch_slice
(
Tensor
t
,
T
i
)
{
constexpr
auto
batch
=
gemm_get_batches
<
Tensor
>
();
constexpr
auto
matrix
=
gemm_get_matrix
<
Tensor
>
();
MIGRAPHX_ASSERT
((
batch
.
index
(
i
)
+
matrix
.
element_space
())
<=
t
.
get_shape
().
element_space
());
return
make_tensor_view
(
t
.
data
()
+
batch
.
index
(
i
),
matrix
);
}
template
<
class
BlocksPerBatch
,
class
T
,
class
...
Ts
>
constexpr
auto
gemm_batch_args
(
index
idx
,
BlocksPerBatch
bpb
,
T
x
,
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
// All tensors should have the same rank
static_assert
(
(
true
and
...
and
(
get_shape_c
<
T
>
{}.
lens
.
size
()
==
get_shape_c
<
Ts
>
{}.
lens
.
size
())));
if
constexpr
(
get_shape_c
<
T
>
{}.
lens
.
size
()
>
2
)
{
// Get the first batch since all batches should have the same number of elements
constexpr
auto
batch
=
gemm_get_batches
<
T
>
();
static_assert
(
(
true
and
...
and
(
batch
.
elements
()
==
gemm_get_batches
<
Ts
>
().
elements
())));
idx
.
group_stride
(
bpb
*
batch
.
elements
(),
[
&
](
auto
gidx
)
{
const
auto
batch_idx
=
gidx
/
bpb
;
f
(
gemm_batch_slice
(
x
,
batch_idx
),
gemm_batch_slice
(
xs
,
batch_idx
)...);
});
}
else
{
f
(
x
,
xs
...);
}
};
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment