Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c96139f8
Commit
c96139f8
authored
Oct 03, 2023
by
Alan Turner
Browse files
Move common functions to ck.hpp + other cleanup
parent
a9b32b71
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
196 additions
and
466 deletions
+196
-466
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-3
src/targets/gpu/include/migraphx/gpu/ck.hpp
src/targets/gpu/include/migraphx/gpu/ck.hpp
+182
-0
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+4
-227
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+6
-232
test/verify/ck_gemm_softmax_gemm.cpp
test/verify/ck_gemm_softmax_gemm.cpp
+1
-4
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
c96139f8
...
@@ -76,7 +76,7 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
...
@@ -76,7 +76,7 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
struct
ck_gemm_softmax_gemm
struct
ck_gemm_softmax_gemm
{
{
operation
op
=
make_op
(
"dot"
);
operation
op
=
make_op
(
"dot"
);
double
scale
=
1.0
;
float
scale
=
1.0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -203,7 +203,7 @@ struct find_ck_gemm
...
@@ -203,7 +203,7 @@ struct find_ck_gemm
}
}
};
};
static
auto
is_mul_module
(
module
&
m
)
auto
is_mul_module
(
module
&
m
)
{
{
auto
is_mul
=
auto
is_mul
=
match
::
arg
(
0
)(
match
::
name
(
"mul"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"@param"
))));
match
::
arg
(
0
)(
match
::
name
(
"mul"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"@param"
))));
...
@@ -243,7 +243,7 @@ struct find_ck_gemm_softmax_gemm
...
@@ -243,7 +243,7 @@ struct find_ck_gemm_softmax_gemm
if
(
not
ck_gemm_softmax_gemm
::
is_ck_supported_type
(
gemm1_ins
->
get_shape
().
type
()))
if
(
not
ck_gemm_softmax_gemm
::
is_ck_supported_type
(
gemm1_ins
->
get_shape
().
type
()))
return
;
return
;
double
scale
=
1.0
;
float
scale
=
1.0
;
scale_lit
->
eval
().
visit
([
&
](
const
auto
s
)
{
scale_lit
->
eval
().
visit
([
&
](
const
auto
s
)
{
// CK only supports single-valued scale
// CK only supports single-valued scale
if
(
std
::
all_of
(
if
(
std
::
all_of
(
...
...
src/targets/gpu/include/migraphx/gpu/ck.hpp
0 → 100644
View file @
c96139f8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/compile_src.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
// NOLINTNEXTLINE
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
template
<
class
P
>
std
::
string
ck_disable_warnings
(
P
p
)
{
return
interpolate_string
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
first
,
p
.
second
}}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
result
;
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&&
p
)
{
return
std
::
make_pair
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
static
std
::
vector
<
src_file
>
create_ck_headers
()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&&
p
)
{
return
src_file
{
fs
::
path
{
p
.
first
},
{
p
.
second
.
data
(),
p
.
second
.
data
()
+
p
.
second
.
size
()}};
});
return
srcs
;
}
static
const
std
::
vector
<
src_file
>&
ck_headers
()
{
static
const
auto
&
headers
=
create_ck_headers
();
return
headers
;
}
inline
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
inline
float
matrix_distance
(
const
shape
&
x
,
const
shape
&
y
)
{
if
(
x
.
type
()
!=
y
.
type
())
return
std
::
numeric_limits
<
float
>::
max
();
if
(
transposed_matrix
(
x
)
!=
transposed_matrix
(
y
))
return
std
::
numeric_limits
<
float
>::
max
();
auto
sum_squared
=
std
::
inner_product
(
x
.
lens
().
rbegin
(),
x
.
lens
().
rbegin
()
+
2
,
y
.
lens
().
rbegin
(),
0
,
std
::
plus
<>
{},
[](
auto
a
,
auto
b
)
{
return
(
a
-
b
)
*
(
a
-
b
);
});
return
std
::
sqrt
(
sum_squared
);
}
inline
std
::
string
get_layout
(
const
shape
&
s
)
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
inline
ck
::
host
::
DataType
get_type
(
const
shape
&
s
)
{
if
(
s
.
type
()
==
shape
::
half_type
)
return
ck
::
host
::
DataType
::
Half
;
else
if
(
s
.
type
()
==
shape
::
float_type
)
return
ck
::
host
::
DataType
::
Float
;
else
if
(
s
.
type
()
==
shape
::
int8_type
)
return
ck
::
host
::
DataType
::
Int8
;
else
if
(
s
.
type
()
==
shape
::
int32_type
)
return
ck
::
host
::
DataType
::
Int32
;
MIGRAPHX_THROW
(
"Unsupported ck type"
);
}
inline
std
::
size_t
get_batch_count
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
lens
().
rbegin
()
+
2
,
s
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
inline
void
fold_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
batch_count
=
get_batch_count
(
s
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
if
(
transposed_matrix
(
s
))
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
else
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
}
inline
void
remove_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
inline
bool
standard_batch
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
3
)
return
true
;
std
::
vector
<
std
::
size_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
);
std
::
vector
<
std
::
size_t
>
strides
(
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
);
auto
base
=
*
(
s
.
lens
().
end
()
-
2
)
*
*
(
s
.
lens
().
end
()
-
1
);
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
stride
)
{
return
stride
/
base
;
});
return
shape
{
s
.
type
(),
lens
,
strides
}.
standard
();
}
inline
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
{
const
auto
&
b_shape
=
inputs
[
1
];
if
(
std
::
any_of
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
[](
auto
input
)
{
return
not
standard_batch
(
input
);
}))
return
false
;
const
auto
&
b_strides
=
b_shape
.
strides
();
return
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
src/targets/gpu/jit/ck_gemm.cpp
View file @
c96139f8
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_gen.hpp>
...
@@ -37,8 +38,6 @@
...
@@ -37,8 +38,6 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -46,12 +45,6 @@ namespace gpu {
...
@@ -46,12 +45,6 @@ namespace gpu {
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING_VALUE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
#include <args.hpp>
...
@@ -79,230 +72,18 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
...
@@ -79,230 +72,18 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
// NOLINTNEXTLINE
static
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
template
<
class
P
>
static
std
::
string
ck_disable_warnings
(
P
p
)
{
return
interpolate_string
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
first
,
p
.
second
}}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
result
;
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&&
p
)
{
return
std
::
make_pair
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
static
std
::
vector
<
src_file
>
create_ck_headers
()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&&
p
)
{
return
src_file
{
fs
::
path
{
p
.
first
},
{
p
.
second
.
data
(),
p
.
second
.
data
()
+
p
.
second
.
size
()}};
});
return
srcs
;
}
static
const
std
::
vector
<
src_file
>&
ck_headers
()
{
static
const
auto
&
headers
=
create_ck_headers
();
return
headers
;
}
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
using
tuning_entry
=
std
::
pair
<
std
::
vector
<
shape
>
,
size_t
>
;
static
std
::
vector
<
tuning_entry
>
read_tuning
(
const
std
::
string
&
s
)
{
if
(
not
fs
::
exists
(
s
))
return
{};
return
from_value
<
std
::
vector
<
tuning_entry
>>
(
from_json_string
(
read_string
(
s
)));
}
static
float
matrix_distance
(
const
shape
&
x
,
const
shape
&
y
)
{
if
(
x
.
type
()
!=
y
.
type
())
return
std
::
numeric_limits
<
float
>::
max
();
if
(
transposed_matrix
(
x
)
!=
transposed_matrix
(
y
))
return
std
::
numeric_limits
<
float
>::
max
();
auto
sum_squared
=
std
::
inner_product
(
x
.
lens
().
rbegin
(),
x
.
lens
().
rbegin
()
+
2
,
y
.
lens
().
rbegin
(),
0
,
std
::
plus
<>
{},
[](
auto
a
,
auto
b
)
{
return
(
a
-
b
)
*
(
a
-
b
);
});
return
std
::
sqrt
(
sum_squared
);
}
static
std
::
size_t
get_tuning_for
(
const
std
::
vector
<
shape
>&
inputs
)
{
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
{
std
::
cout
<<
"*********** Warning: No CK tuning! for config:"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
}
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
{
std
::
cout
<<
"*********** Warning: CK tuning missing for config!"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
std
::
vector
<
std
::
pair
<
float
,
std
::
size_t
>>
w
;
std
::
transform
(
tuning
.
begin
(),
tuning
.
end
(),
std
::
back_inserter
(
w
),
[
&
](
const
auto
&
p
)
{
if
(
inputs
.
size
()
<
3
or
p
.
first
.
size
()
<
3
)
MIGRAPHX_THROW
(
"Invalid CK config"
);
auto
avg_distance
=
std
::
inner_product
(
p
.
first
.
begin
(),
p
.
first
.
begin
()
+
3
,
inputs
.
begin
(),
0.0
f
,
std
::
plus
<>
{},
[](
const
auto
&
x
,
const
auto
&
y
)
{
return
matrix_distance
(
x
,
y
)
/
3.0
f
;
});
return
std
::
make_pair
(
avg_distance
,
p
.
second
);
});
std
::
sort
(
w
.
begin
(),
w
.
end
());
std
::
size_t
default_value
=
4
;
if
(
not
w
.
empty
())
default_value
=
w
.
front
().
second
;
auto
tuning_val
=
value_of
(
MIGRAPHX_CK_TUNING_VALUE
{},
default_value
);
std
::
cout
<<
"*********** Warning: CK try tuning: "
<<
tuning_val
<<
std
::
endl
;
return
tuning_val
;
}
return
it
->
second
;
}
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
{
static
std
::
string
get_layout
(
const
shape
&
s
)
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
static
ck
::
host
::
DataType
get_type
(
const
shape
&
s
)
{
if
(
s
.
type
()
==
shape
::
half_type
)
return
ck
::
host
::
DataType
::
Half
;
else
if
(
s
.
type
()
==
shape
::
float_type
)
return
ck
::
host
::
DataType
::
Float
;
else
if
(
s
.
type
()
==
shape
::
int8_type
)
return
ck
::
host
::
DataType
::
Int8
;
else
if
(
s
.
type
()
==
shape
::
int32_type
)
return
ck
::
host
::
DataType
::
Int32
;
MIGRAPHX_THROW
(
"Unsupported ck type"
);
}
template
<
class
Iterator
,
class
F
>
static
std
::
string
ck_tuple
(
Iterator
start
,
Iterator
last
,
F
f
)
{
std
::
vector
<
std
::
string
>
s
;
std
::
transform
(
start
,
last
,
std
::
back_inserter
(
s
),
f
);
return
"ck::Tuple<"
+
join_strings
(
s
,
","
)
+
">"
;
}
static
std
::
vector
<
shape
>
adjust_inputs
(
std
::
vector
<
shape
>
inputs
,
bool
&
swap_inputs
)
{
swap_inputs
=
false
;
auto
c_shape
=
inputs
.
back
();
if
(
not
transposed_matrix
(
c_shape
))
return
inputs
;
std
::
vector
<
int64_t
>
perm
(
c_shape
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
swap
(
perm
[
perm
.
size
()
-
1
],
perm
[
perm
.
size
()
-
2
]);
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
shape
s
)
{
return
reorder_shape
(
s
,
perm
);
});
swap_inputs
=
true
;
return
inputs
;
}
static
std
::
size_t
get_batch_count
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
lens
().
rbegin
()
+
2
,
s
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
static
void
fold_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
batch_count
=
get_batch_count
(
s
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
if
(
transposed_matrix
(
s
))
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
else
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
}
static
void
remove_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
static
bool
standard_batch
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
3
)
return
true
;
std
::
vector
<
std
::
size_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
);
std
::
vector
<
std
::
size_t
>
strides
(
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
);
auto
base
=
*
(
s
.
lens
().
end
()
-
2
)
*
*
(
s
.
lens
().
end
()
-
1
);
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
stride
)
{
return
stride
/
base
;
});
return
shape
{
s
.
type
(),
lens
,
strides
}.
standard
();
}
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
const
auto
&
b_shape
=
inputs
[
1
];
if
(
std
::
any_of
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
[](
auto
input
)
{
return
not
standard_batch
(
input
);
}))
return
false
;
const
auto
&
b_strides
=
b_shape
.
strides
();
return
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
const
value
&
v
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
// cppcheck-suppress unreadVariable
auto
rank
=
a_shape
.
ndim
();
auto
rank
=
a_shape
.
ndim
();
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
...
@@ -352,12 +133,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -352,12 +133,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
4
);
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
0
);
if
(
not
v
.
contains
(
"tuning_value"
))
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
problem
=
create_problem
(
inputs
,
v
);
auto
problem
=
create_problem
(
inputs
,
v
);
...
...
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
View file @
c96139f8
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
...
@@ -37,8 +38,6 @@
...
@@ -37,8 +38,6 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -46,12 +45,6 @@ namespace gpu {
...
@@ -46,12 +45,6 @@ namespace gpu {
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING_VALUE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_softmax_gemm_kernel
=
R"__migraphx__(
static
const
char
*
const
ck_gemm_softmax_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
#include <args.hpp>
...
@@ -82,236 +75,22 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
...
@@ -82,236 +75,22 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
// NOLINTNEXTLINE
static
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
template
<
class
P
>
static
std
::
string
ck_disable_warnings
(
P
p
)
{
return
interpolate_string
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
first
,
p
.
second
}}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
result
;
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&&
p
)
{
return
std
::
make_pair
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
static
std
::
vector
<
src_file
>
create_ck_headers
()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&&
p
)
{
return
src_file
{
fs
::
path
{
p
.
first
},
{
p
.
second
.
data
(),
p
.
second
.
data
()
+
p
.
second
.
size
()}};
});
return
srcs
;
}
static
const
std
::
vector
<
src_file
>&
ck_headers
()
{
static
const
auto
&
headers
=
create_ck_headers
();
return
headers
;
}
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
using
tuning_entry
=
std
::
pair
<
std
::
vector
<
shape
>
,
size_t
>
;
static
std
::
vector
<
tuning_entry
>
read_tuning
(
const
std
::
string
&
s
)
{
if
(
not
fs
::
exists
(
s
))
return
{};
return
from_value
<
std
::
vector
<
tuning_entry
>>
(
from_json_string
(
read_string
(
s
)));
}
static
float
matrix_distance
(
const
shape
&
x
,
const
shape
&
y
)
{
if
(
x
.
type
()
!=
y
.
type
())
return
std
::
numeric_limits
<
float
>::
max
();
if
(
transposed_matrix
(
x
)
!=
transposed_matrix
(
y
))
return
std
::
numeric_limits
<
float
>::
max
();
auto
sum_squared
=
std
::
inner_product
(
x
.
lens
().
rbegin
(),
x
.
lens
().
rbegin
()
+
2
,
y
.
lens
().
rbegin
(),
0
,
std
::
plus
<>
{},
[](
auto
a
,
auto
b
)
{
return
(
a
-
b
)
*
(
a
-
b
);
});
return
std
::
sqrt
(
sum_squared
);
}
static
std
::
size_t
get_tuning_for
(
const
std
::
vector
<
shape
>&
inputs
)
{
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
{
std
::
cout
<<
"*********** Warning: No CK tuning! for config:"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
3
]
<<
std
::
endl
;
}
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
{
std
::
cout
<<
"*********** Warning: CK tuning missing for config!"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
0
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
1
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
2
]
<<
std
::
endl
;
std
::
cout
<<
" "
<<
inputs
[
3
]
<<
std
::
endl
;
std
::
vector
<
std
::
pair
<
float
,
std
::
size_t
>>
w
;
std
::
transform
(
tuning
.
begin
(),
tuning
.
end
(),
std
::
back_inserter
(
w
),
[
&
](
const
auto
&
p
)
{
if
(
inputs
.
size
()
<
3
or
p
.
first
.
size
()
<
3
)
MIGRAPHX_THROW
(
"Invalid CK config"
);
auto
avg_distance
=
std
::
inner_product
(
p
.
first
.
begin
(),
p
.
first
.
begin
()
+
3
,
inputs
.
begin
(),
0.0
f
,
std
::
plus
<>
{},
[](
const
auto
&
x
,
const
auto
&
y
)
{
return
matrix_distance
(
x
,
y
)
/
3.0
f
;
});
return
std
::
make_pair
(
avg_distance
,
p
.
second
);
});
std
::
sort
(
w
.
begin
(),
w
.
end
());
std
::
size_t
default_value
=
5
;
if
(
not
w
.
empty
())
default_value
=
w
.
front
().
second
;
auto
tuning_val
=
value_of
(
MIGRAPHX_CK_TUNING_VALUE
{},
default_value
);
std
::
cout
<<
"*********** Warning: CK try tuning: "
<<
tuning_val
<<
std
::
endl
;
return
tuning_val
;
}
return
it
->
second
;
}
struct
ck_gemm_softmax_gemm_compiler
:
compiler
<
ck_gemm_softmax_gemm_compiler
>
struct
ck_gemm_softmax_gemm_compiler
:
compiler
<
ck_gemm_softmax_gemm_compiler
>
{
{
static
std
::
string
get_layout
(
const
shape
&
s
)
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
static
ck
::
host
::
DataType
get_type
(
const
shape
&
s
)
{
if
(
s
.
type
()
==
shape
::
half_type
)
return
ck
::
host
::
DataType
::
Half
;
else
if
(
s
.
type
()
==
shape
::
float_type
)
return
ck
::
host
::
DataType
::
Float
;
else
if
(
s
.
type
()
==
shape
::
int8_type
)
return
ck
::
host
::
DataType
::
Int8
;
else
if
(
s
.
type
()
==
shape
::
int32_type
)
return
ck
::
host
::
DataType
::
Int32
;
MIGRAPHX_THROW
(
"Unsupported ck type"
);
}
template
<
class
Iterator
,
class
F
>
static
std
::
string
ck_tuple
(
Iterator
start
,
Iterator
last
,
F
f
)
{
std
::
vector
<
std
::
string
>
s
;
std
::
transform
(
start
,
last
,
std
::
back_inserter
(
s
),
f
);
return
"ck::Tuple<"
+
join_strings
(
s
,
","
)
+
">"
;
}
static
std
::
vector
<
shape
>
adjust_inputs
(
std
::
vector
<
shape
>
inputs
,
bool
&
swap_inputs
)
{
swap_inputs
=
false
;
auto
c_shape
=
inputs
.
back
();
if
(
not
transposed_matrix
(
c_shape
))
return
inputs
;
std
::
vector
<
int64_t
>
perm
(
c_shape
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
swap
(
perm
[
perm
.
size
()
-
1
],
perm
[
perm
.
size
()
-
2
]);
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
shape
s
)
{
return
reorder_shape
(
s
,
perm
);
});
swap_inputs
=
true
;
return
inputs
;
}
static
std
::
size_t
get_batch_count
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
lens
().
rbegin
()
+
2
,
s
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
static
void
fold_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
batch_count
=
get_batch_count
(
s
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
if
(
transposed_matrix
(
s
))
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
else
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
}
static
void
remove_batch_dims
(
shape
&
s
)
{
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
return
;
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
std
::
vector
<
std
::
string
>
names
()
const
std
::
vector
<
std
::
string
>
names
()
const
{
{
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
}
}
static
bool
standard_batch
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
3
)
return
true
;
std
::
vector
<
std
::
size_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
);
std
::
vector
<
std
::
size_t
>
strides
(
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
);
auto
base
=
*
(
s
.
lens
().
end
()
-
2
)
*
*
(
s
.
lens
().
end
()
-
1
);
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
stride
)
{
return
stride
/
base
;
});
return
shape
{
s
.
type
(),
lens
,
strides
}.
standard
();
}
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
const
auto
&
b_shape
=
inputs
[
1
];
if
(
std
::
any_of
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
[](
auto
input
)
{
return
not
standard_batch
(
input
);
}))
return
false
;
const
auto
&
b_strides
=
b_shape
.
strides
();
return
std
::
all_of
(
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b1_shape
=
inputs
[
2
];
const
auto
&
b1_shape
=
inputs
[
2
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
// cppcheck-suppress unreadVariable
auto
rank
=
a_shape
.
ndim
();
auto
rank
=
a_shape
.
ndim
();
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
...
@@ -349,13 +128,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -349,13 +128,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b1_shape
=
inputs
[
2
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
4
);
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
5
);
if
(
not
v
.
contains
(
"tuning_value"
))
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
b1_shape
,
c_shape
});
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
problem
=
create_problem
(
inputs
,
v
);
auto
problem
=
create_problem
(
inputs
,
v
);
...
@@ -399,7 +173,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -399,7 +173,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
"blocks_per_batch"
,
to_string
(
blocks_per_batch
)},
{
"blocks_per_batch"
,
to_string
(
blocks_per_batch
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"kernel"
,
options
.
kernel_name
}});
{
"kernel"
,
options
.
kernel_name
}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
...
...
test/verify/ck_gemm_softmax_gemm.cpp
View file @
c96139f8
...
@@ -35,17 +35,14 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
...
@@ -35,17 +35,14 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
auto
m2_elements
=
1
*
12
*
256
*
256
;
auto
m2_elements
=
m2_shape
.
elements
()
;
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
auto
c
=
mm
->
add_parameter
(
"4"
,
m1_shape
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
auto
zero
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
zeros
});
auto
zero
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
zeros
});
std
::
vector
<
float
>
ones
(
m2_elements
,
1
);
auto
one
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
ones
});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment