Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8a9cfb25
Unverified
Commit
8a9cfb25
authored
Aug 26, 2019
by
mvermeulen
Committed by
GitHub
Aug 26, 2019
Browse files
Merge branch 'develop' into bugs_for_bert
parents
762fca25
7534546a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
144 additions
and
83 deletions
+144
-83
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+47
-4
src/driver/main.cpp
src/driver/main.cpp
+24
-1
src/driver/perf.cpp
src/driver/perf.cpp
+17
-0
src/driver/perf.hpp
src/driver/perf.hpp
+1
-0
src/generate.cpp
src/generate.cpp
+11
-0
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+10
-0
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+4
-3
src/targets/gpu/quant_gemm.cpp
src/targets/gpu/quant_gemm.cpp
+30
-75
No files found.
src/driver/argument_parser.hpp
View file @
8a9cfb25
...
...
@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DRIVER_STATIC static
#endif
// Yields T with any reference removed first and the top-level const/volatile
// qualifiers removed afterwards.
template <class T>
using bare = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
namespace detail {

// Preferred overload (the int tag is an exact match for an int argument).
// It only participates in overload resolution when the expression
// `c.insert(c.end(), *c.begin())` is well-formed; its return type is then
// std::true_type. Declaration only: it is meant to be used inside decltype.
template <class Container>
auto is_container(int, Container&& c)
    -> decltype(c.insert(c.end(), *c.begin()), std::true_type{});

// Fallback overload, selected through the int -> float conversion when the
// overload above is discarded. Declaration only.
template <class Container>
std::false_type is_container(float, Container&&);

} // namespace detail
// Trait that inherits from whatever type detail::is_container returns when
// called with an int tag and a value of type T.
template <class T>
struct is_container : decltype(detail::is_container(0, std::declval<T>()))
{
};
template
<
class
T
>
using
is_multi_value
=
std
::
integral_constant
<
bool
,
(
is_container
<
T
>
{}
and
not
std
::
is_convertible
<
T
,
std
::
string
>
{})
>
;
template
<
class
T
>
struct
value_parser
{
template
<
MIGRAPHX_REQUIRES
(
not
std
::
is_enum
<
T
>{})
>
template
<
MIGRAPHX_REQUIRES
(
not
std
::
is_enum
<
T
>{}
and
not
is_multi_value
<
T
>
{}
)
>
static
T
apply
(
const
std
::
string
&
x
)
{
T
result
;
...
...
@@ -43,7 +65,7 @@ struct value_parser
return
result
;
}
template
<
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
template
<
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{}
and
not
is_multi_value
<
T
>
{}
)
>
static
T
apply
(
const
std
::
string
&
x
)
{
std
::
ptrdiff_t
i
;
...
...
@@ -54,6 +76,15 @@ struct value_parser
throw
std
::
runtime_error
(
"Failed to parse: "
+
x
);
return
static_cast
<
T
>
(
i
);
}
// Overload enabled for multi-value, non-enum T: parses the string as one
// element of type T::value_type and returns a T holding just that element.
template <MIGRAPHX_REQUIRES(is_multi_value<T>{} and not std::is_enum<T>{})>
static T apply(const std::string& x)
{
    using element_type = typename T::value_type;
    T parsed;
    // Delegate the actual conversion to the element type's parser.
    parsed.insert(parsed.end(), value_parser<element_type>::apply(x));
    return parsed;
}
};
struct
argument_parser
...
...
@@ -69,6 +100,18 @@ struct argument_parser
unsigned
nargs
=
1
;
};
// String form of a multi-value argument: forwards to to_string_range.
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
    return to_string_range(x);
}
// String form of a single-value (non multi-value) argument: forwards to
// to_string.
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
    return to_string(x);
}
template
<
class
T
,
class
...
Fs
>
void
operator
()(
T
&
x
,
const
std
::
vector
<
std
::
string
>&
flags
,
Fs
...
fs
)
{
...
...
@@ -81,7 +124,7 @@ struct argument_parser
argument
&
arg
=
arguments
.
back
();
arg
.
type
=
migraphx
::
get_type_name
<
T
>
();
arg
.
default_value
=
to
_string
(
x
);
arg
.
default_value
=
as
_string
_value
(
x
);
migraphx
::
each_args
([
&
](
auto
f
)
{
f
(
x
,
arg
);
},
fs
...);
}
...
...
@@ -127,7 +170,7 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC
auto
append
()
{
return
write_action
([](
auto
&
,
auto
&
x
,
auto
&
params
)
{
using
type
=
typename
decltype
(
params
)
::
value_type
;
using
type
=
typename
bare
<
decltype
(
params
)
>
::
value_type
;
std
::
transform
(
params
.
begin
(),
params
.
end
(),
std
::
inserter
(
x
,
x
.
end
()),
...
...
src/driver/main.cpp
View file @
8a9cfb25
...
...
@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
...
...
@@ -80,11 +81,13 @@ struct compiler
{
loader
l
;
bool
gpu
=
true
;
std
::
vector
<
std
::
string
>
fill1
;
// Registers this command's options on the argument parser: the loader's own
// options first, then --gpu and --cpu (both write the `gpu` flag, with
// opposite values) and --fill1, which uses the parser's append action on
// `fill1`.
void parse(argument_parser& ap)
{
    l.parse(ap);

    ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
    ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
    ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
}
program
compile
()
...
...
@@ -94,7 +97,14 @@ struct compiler
return
p
;
}
auto
params
(
const
program
&
p
)
{
return
create_param_map
(
p
,
gpu
);
}
auto
params
(
const
program
&
p
)
{
program
::
parameter_map
m
;
for
(
auto
&&
s
:
fill1
)
m
[
s
]
=
fill_argument
(
p
.
get_parameter_shape
(
s
),
1
);
fill_param_map
(
m
,
p
,
gpu
);
return
m
;
}
};
struct
read
:
command
<
read
>
...
...
@@ -109,6 +119,19 @@ struct read : command<read>
}
};
// Driver command that obtains a program from the loader and prints each
// entry of its parameter shapes as "<first>: <second>", one per line.
struct params : command<params>
{
    loader l;

    // Registers the loader's options on the argument parser.
    void parse(argument_parser& ap) { l.parse(ap); }

    // Loads the program and writes every parameter entry to std::cout.
    void run()
    {
        auto p         = l.load();
        auto&& entries = p.get_parameter_shapes();
        for(auto&& entry : entries)
        {
            std::cout << entry.first << ": " << entry.second << std::endl;
        }
    }
};
struct
verify
:
command
<
verify
>
{
loader
l
;
...
...
src/driver/perf.cpp
View file @
8a9cfb25
...
...
@@ -11,6 +11,23 @@ namespace migraphx {
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Completes the parameter map `m` for program `p`.
// For each entry of p.get_parameter_shapes() the slot m[name] is looked up
// (operator[] creates it when absent); a slot that reports empty() is
// assigned generate_argument(shape). When built with HAVE_GPU and `gpu` is
// true, every slot is then replaced by gpu::to_gpu(slot). `m` is modified in
// place and a copy of it is returned.
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu)
{
    for(auto&& entry : p.get_parameter_shapes())
    {
        argument& slot = m[entry.first];
        if(slot.empty())
        {
            slot = generate_argument(entry.second);
        }
#ifdef HAVE_GPU
        if(gpu)
        {
            slot = gpu::to_gpu(slot);
        }
#else
        // Without GPU support the flag is intentionally unused.
        (void)gpu;
#endif
    }
    return m;
}
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
)
{
program
::
parameter_map
m
;
...
...
src/driver/perf.hpp
View file @
8a9cfb25
...
...
@@ -7,6 +7,7 @@ namespace migraphx {
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
program
::
parameter_map
fill_param_map
(
program
::
parameter_map
&
m
,
const
program
&
p
,
bool
gpu
);
program
::
parameter_map
create_param_map
(
const
program
&
p
,
bool
gpu
=
true
);
void
compile_program
(
program
&
p
,
bool
gpu
=
true
);
...
...
src/generate.cpp
View file @
8a9cfb25
...
...
@@ -3,6 +3,17 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// Creates an argument of shape `s` whose data is produced by
// fill_tensor_data<type>(s, value), where `type` is the element type the
// shape's visit_type callback reports. The buffer is captured by copy in the
// closure handed to the argument, which returns it as a char pointer.
argument fill_argument(shape s, unsigned long value)
{
    argument filled;
    s.visit_type([&](auto as) {
        using element = typename decltype(as)::type;
        auto buffer   = fill_tensor_data<element>(s, value);
        filled        = {s, [buffer]() mutable { return reinterpret_cast<char*>(buffer.data()); }};
    });
    return filled;
}
argument
generate_argument
(
shape
s
,
unsigned
long
seed
)
{
argument
result
;
...
...
src/include/migraphx/generate.hpp
View file @
8a9cfb25
...
...
@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return
result
;
}
// Returns a vector of s.elements() values of type T, each one assigned from
// `value` (converted to T by assignment).
template <class T>
std::vector<T> fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{
    std::vector<T> result(s.elements());
    std::fill(result.begin(), result.end(), value);
    return result;
}
argument
fill_argument
(
shape
s
,
unsigned
long
value
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
...
...
src/include/migraphx/requires.hpp
View file @
8a9cfb25
...
...
@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void
#else
#define MIGRAPHX_REQUIRES(...) \
bool MIGRAPHX_REQUIRES_VAR() = true, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
#define MIGRAPHX_REQUIRES(...) \
long MIGRAPHX_REQUIRES_VAR() = __LINE__, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \
(migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0
#endif
...
...
src/targets/gpu/quant_gemm.cpp
View file @
8a9cfb25
...
...
@@ -8,51 +8,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
template
<
class
...
Ts
>
rocblas_status
generic_rocblas_gemm_ex
(
Ts
&&
...
xs
)
{
return
rocblas_gemm_ex
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
...
Ts
>
rocblas_status
generic_rocblas_batched_gemm_ex
(
Ts
&&
...
xs
)
{
return
rocblas_gemm_strided_batched_ex
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
>
struct
compute_rocblas_type
{
using
type
=
T
;
};
template
<
class
T
>
struct
compute_rocblas_type
<
const
T
>
{
using
type
=
const
typename
compute_rocblas_type
<
T
>::
type
;
};
template
<
>
struct
compute_rocblas_type
<
half
>
{
using
type
=
rocblas_half
;
};
template
<
class
T
>
using
rb_type
=
typename
compute_rocblas_type
<
T
>::
type
;
template
<
class
T
>
rb_type
<
T
>
to_rocblas_type
(
T
x
)
{
return
reinterpret_cast
<
const
rb_type
<
T
>&>
(
x
);
}
template
<
class
T
>
rb_type
<
T
>*
to_rocblas_type
(
T
*
x
)
{
return
reinterpret_cast
<
rb_type
<
T
>*>
(
x
);
}
shape
rocblas_quant_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
...
...
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
)
)
;
auto
beta_r
=
to_rocblas_type
(
as
(
beta
)
)
;
auto
alpha_r
=
as
(
op
.
alpha
);
auto
beta_r
=
as
(
beta
);
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
())
)
;
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
as
.
from
(
arg
.
data
());
};
assert
(
k
%
4
==
0
);
auto
num_matrices
=
std
::
accumulate
(
...
...
@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
generic_
rocblas_gemm_ex
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
rocblas_datatype_i8_r
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
rocblas_datatype_i8_r
,
lda
,
&
beta_r
,
to_pointer
(
args
[
2
]),
rocblas_datatype_i32_r
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
rocblas_datatype_i32_r
,
ldc
,
rocblas_datatype_i32_r
,
rocblas_gemm_algo_standard
,
0
,
0
,
nullptr
,
nullptr
);
rocblas_gemm_ex
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
rocblas_datatype_i8_r
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
rocblas_datatype_i8_r
,
lda
,
&
beta_r
,
to_pointer
(
args
[
2
]),
rocblas_datatype_i32_r
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
rocblas_datatype_i32_r
,
ldc
,
rocblas_datatype_i32_r
,
rocblas_gemm_algo_standard
,
0
,
0
,
nullptr
,
nullptr
);
}
else
{
generic_rocblas
_batched_
gemm_
ex
(
rocblas_gemm_strided
_batched_ex
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment