Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
359bb1cd
Commit
359bb1cd
authored
Apr 08, 2023
by
Paul
Browse files
Merge branch 'ck-gemm-fused-transpose' into sd-opt
parents
1ac14290
55b363c9
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
266 additions
and
0 deletions
+266
-0
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
...ets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
+68
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+8
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
test/verify/gemm_add_relu.cpp
test/verify/gemm_add_relu.cpp
+45
-0
tools/tune_ck.py
tools/tune_ck.py
+142
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/gemm_batcher.hpp
0 → 100644
View file @
359bb1cd
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>
namespace
migraphx
{
// Builds a compile-time shape covering only the batch dimensions of Tensor,
// i.e. every dimension except the trailing two (the matrix rows/cols).
// The original strides for those dimensions are preserved, so the returned
// shape can translate a flat batch index into an element offset into the
// full tensor.
template <class Tensor>
constexpr auto gemm_get_batches()
{
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    // Keep the first rank-2 lengths. Each value is lifted back into a
    // compile-time constant with _c<...> so make_const_array receives
    // integral-constant arguments.
    constexpr auto new_lens = sequence(lens.size() - _c<2>, [&](auto... is) {
        return make_const_array(_c<lens[is]>...);
    });
    // Same construction for the matching strides.
    constexpr auto new_strides = sequence(strides.size() - _c<2>, [&](auto... is) {
        return make_const_array(_c<strides[is]>...);
    });
    return make_shape(new_lens, new_strides);
}
// Builds a compile-time shape describing the trailing 2-D matrix of Tensor:
// the last two lengths and their original strides. Together with
// gemm_get_batches this splits a batched tensor into (batches, matrix).
template <class Tensor>
constexpr auto gemm_get_matrix()
{
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    // Indices of the last two dimensions (rows, cols).
    constexpr auto m = lens.size() - _c<2>;
    constexpr auto n = lens.size() - _c<1>;
    // Lift the selected lengths/strides back into compile-time constants.
    constexpr auto new_lens    = make_const_array(_c<lens[m]>, _c<lens[n]>);
    constexpr auto new_strides = make_const_array(_c<strides[m]>, _c<strides[n]>);
    return make_shape(new_lens, new_strides);
}
// Returns a 2-D tensor_view of the i-th batch element of t: the data pointer
// is advanced by the offset of batch i, and the view is shaped as a single
// matrix (the tensor's trailing two dimensions).
template <class Tensor, class T>
constexpr auto gemm_batch_slice(Tensor t, T i)
{
    // Shape describing just one matrix of the batch.
    constexpr auto matrix_shape = gemm_get_matrix<Tensor>();
    // Shape over the batch dimensions, used to turn i into an element offset.
    constexpr auto batch_shape = gemm_get_batches<Tensor>();
    auto offset = batch_shape.index(i);
    return make_tensor_view(t.data() + offset, matrix_shape);
}
// Adapts a per-matrix gemm kernel to batched tensors. Returns a callable
// that, given a kernel functor f, either iterates the batch dimensions
// (rank > 2) — distributing `bpb` blocks per batch element via group_stride
// and passing f a 2-D slice of every tensor — or, for plain rank-2 tensors,
// calls f with the tensors unchanged.
//
// idx: kernel launch index helper; bpb: blocks assigned to each batch;
// x, xs...: the input/output tensors (all must share rank and batch size).
template <class BlocksPerBatch, class T, class... Ts>
constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
{
    // Capture by value: tensors and index are lightweight views/handles here.
    return [=](auto f) {
        // All tensors should have the same rank
        static_assert((true and ... and
                       (get_shape_c<T>{}.lens.size() == get_shape_c<Ts>{}.lens.size())));
        if constexpr(get_shape_c<T>{}.lens.size() > 2)
        {
            // Get the first batch since all batches should have the same number of elements
            constexpr auto batch = gemm_get_batches<T>();
            static_assert((true and ... and
                           (batch.elements() == gemm_get_batches<Ts>().elements())));
            // bpb blocks per batch element: gidx / bpb recovers which batch
            // this group works on; the slice views share that batch offset.
            idx.group_stride(bpb * batch.elements(), [&](auto gidx) {
                const auto batch_idx = gidx / bpb;
                f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...);
            });
        }
        else
        {
            // No batch dimensions; invoke the kernel directly.
            f(x, xs...);
        }
    };
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
359bb1cd
...
@@ -130,6 +130,8 @@ struct index
...
@@ -130,6 +130,8 @@ struct index
return
blockDim
.
x
;
return
blockDim
.
x
;
}
}
#endif
#endif
// Number of workgroups in the launch: total global size divided by the
// compile-time maximum workgroup size.
constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
template
<
class
N
,
class
Stride
>
template
<
class
N
,
class
Stride
>
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
static
constexpr
auto
max_stride_iterations
(
N
n
,
Stride
stride
)
{
{
...
@@ -231,6 +233,12 @@ struct index
...
@@ -231,6 +233,12 @@ struct index
{
{
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
}
}
// Invokes f for this thread's share of n work items at workgroup
// granularity: iteration starts at the group id and steps by ngroup(),
// mirroring the local_stride helper above but over groups.
// NOTE(review): the <false> template flag presumably disables the
// single-iteration fast path used by local_stride's <true> — confirm
// against for_stride's definition.
template <class F, class N>
__device__ void group_stride(N n, F f) const
{
    for_stride<false>(group, n, ngroup(), f);
}
};
};
#ifdef MIGRAPHX_NLOCAL
#ifdef MIGRAPHX_NLOCAL
...
...
src/targets/gpu/target.cpp
View file @
359bb1cd
...
@@ -57,6 +57,7 @@
...
@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
...
@@ -135,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -135,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
fuse_mlir
{
&
ctx
},
fuse_mlir
{
&
ctx
},
dead_code_elimination
{},
dead_code_elimination
{},
fuse_ck
{
&
ctx
},
dead_code_elimination
{},
lowering
{
&
ctx
,
options
.
offload_copy
},
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
test/verify/gemm_add_relu.cpp
0 → 100644
View file @
359bb1cd
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
// Verifies a half-precision dot -> add -> relu chain, the pattern targeted
// by the CK gemm fusion.
struct gemm_add_relu : verify_program<gemm_add_relu>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        // Positional parameter names ("1", "2", "3") follow the verify-test
        // convention; shapes give a (2x3)*(3x4) product plus a (2x4) bias.
        auto x1 = mm->add_parameter("1", {migraphx::shape::half_type, {2, 3}});
        auto x2 = mm->add_parameter("2", {migraphx::shape::half_type, {3, 4}});
        auto x3 = mm->add_parameter("3", {migraphx::shape::half_type, {2, 4}});
        auto gemm   = mm->add_instruction(migraphx::make_op("dot"), x1, x2);
        auto biased = mm->add_instruction(migraphx::make_op("add"), gemm, x3);
        mm->add_instruction(migraphx::make_op("relu"), biased);
        return p;
    }
};
tools/tune_ck.py
0 → 100644
View file @
359bb1cd
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
,
multiprocessing
,
multiprocessing
.
dummy
@contextlib.contextmanager
def tmp_file(dump=None):
    """Create a temporary file, optionally populate it, and remove it on exit.

    dump: optional callable receiving the open file object to write contents.
    Yields the file's path. The file is closed before yielding so another
    process (e.g. the gpu-driver subprocess) can open it.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # Guard against tmp_name still being None (temp-file creation
        # failed): the original unconditionally called os.unlink(None),
        # raising TypeError and masking the real error.
        if tmp_name is not None:
            os.unlink(tmp_name)
def pretty_print(obj):
    """Print obj as 2-space-indented JSON."""
    text = json.dumps(obj, indent=2)
    print(text)
def run_driver(b):
    """Run ./bin/gpu-driver on a JSON dump of `b`, yielding payload lines.

    Generator: writes `b` to a temp file, invokes the driver on it, echoes
    the driver's stderr, and yields the text after ']: ' for each matching
    stdout line. Raises CalledProcessError if the driver exits non-zero.
    """
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        if not os.path.exists('./bin/gpu-driver'):
            print("./bin/gpu-driver not found")
            # Deliberately hard-kills the process: callers wrap this in
            # broad except blocks, and a missing driver is unrecoverable.
            os.abort()
        # Pass the command as a list to avoid the shell entirely — no
        # quoting hazards if the temp-file path contains spaces.
        cp = subprocess.run(['./bin/gpu-driver', tf], capture_output=True)
        print(cp.stderr.decode())
        cp.check_returncode()
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            if not s:
                continue
            if ']: ' not in s:
                continue
            yield s.split(']: ')[1].strip()
def convert_to_float(s):
    """Convert a timing string like '1.23ms' to its numeric value.

    Strips the two-character unit suffix and returns a float. (The original
    returned the bare string despite the function's name; the only caller
    wraps the result in float(), so returning a float stays compatible.)
    """
    return float(s[:-2])
def get_device_time(s):
    """Extract the device time (last comma-separated field) from a driver line."""
    last_field = s.split(',')[-1]
    return convert_to_float(last_field.strip())
def run_driver_ck(config, tuning, iterations):
    """Run the gpu-driver on one ck_gemm instance.

    config: the gemm input shapes; tuning: tuning value to try;
    iterations: benchmark iterations (0 = compile only).
    """
    settings = {'iterations': iterations}
    compile_op = {
        'name': 'ck_gemm',
        'check': True,
        'tuning_val': tuning,
        'inputs': config,
    }
    return run_driver({'settings': settings, 'compile_op': compile_op})
def benchmark_ck(config, tuning):
    """Benchmark one tuning value; return its device time in ms.

    Returns sys.float_info.max if the instance fails to compile/run, so it
    ranks last in the tuning comparison.
    """
    try:
        for line in run_driver_ck(config, tuning, 100):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
        print("Failed")
        sys.exit(1)
    except Exception:
        # The original bare `except:` also caught the SystemExit raised by
        # sys.exit(1) above, silently converting a hard failure into a
        # max-time result. Catching Exception lets the exit propagate.
        return sys.float_info.max
def benchmark(config, size):
    """Time tuning values 0..size-1 for config; return the fastest one's index."""
    times = []
    for tuning in range(size):
        times.append(benchmark_ck(config, tuning))
    return times.index(min(times))
def parse_log(f):
    """Yield the JSON config from each 'ck_gemm:' line in log file `f`."""
    # `with` closes the log deterministically; the original leaked the
    # handle from a bare open(). Iterating the file object directly also
    # avoids loading the whole log via readlines().
    with open(f) as log:
        for line in log:
            line = line.strip()
            if not line.startswith('ck_gemm:'):
                continue
            line = line[len('ck_gemm:'):].strip()
            config = json.loads(line)
            yield config
def precompile(x):
    """Compile (0 iterations) one (config, tuning) pair; ignore failures.

    x: a (config, tuning_value) tuple, as produced by precompile_log.
    """
    try:
        # list() drains the generator so the driver actually runs.
        list(run_driver_ck(x[0], x[1], 0))
    except Exception:
        # Best-effort warm-up: compile failures are expected for invalid
        # tuning values. Narrowed from bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate.
        pass
def precompile_log(f, n, processes=24):
    """Precompile every (config, tuning) pair from log `f` in parallel.

    f: log file path; n: number of tuning values per config;
    processes: worker-pool size (default keeps the original hard-coded 24).
    """
    solutions = ((config, i) for config in parse_log(f) for i in range(n))
    with multiprocessing.Pool(processes) as p:
        # Consume the lazy imap iterator so every job actually executes.
        list(p.imap(precompile, solutions))
def benchmark_log(f, n):
    """Benchmark every config in log `f` over n tunings.

    Returns a list of [config, best_tuning_index] pairs.
    """
    result = []
    for config in parse_log(f):
        best = benchmark(config, n)
        print("Tuned:", best)
        result.append([config, best])
    return result
def parse_args():
    """Parse the tuner's command-line options."""
    cli = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    cli.add_argument('--log', '-l', type=str, metavar='file',
                     help='Path to logfile')
    cli.add_argument('--out', '-o', type=str, metavar='file',
                     help='Output json file to save tunings')
    cli.add_argument('--precompile', '-p', action='store_true',
                     help='Precompile kernels first in parallel')
    cli.add_argument('-n', type=int,
                     help='Number of instances to tune')
    return cli.parse_args()
def run(args):
    """Drive the tuning flow: optional parallel precompile, then benchmark
    every logged config and save the results as JSON to args.out."""
    if args.precompile:
        precompile_log(args.log, args.n)
    tuned = benchmark_log(args.log, args.n)
    # Close the output file deterministically; the original leaked the
    # handle from open(..., 'w+') and relied on interpreter exit to flush.
    with open(args.out, 'w+') as out:
        json.dump(tuned, out)
if __name__ == '__main__':
    # Guard so importing this module for its helpers (e.g. parse_log)
    # doesn't immediately start a tuning run.
    run(parse_args())
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment