Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d21ccdf
Commit
8d21ccdf
authored
Jan 31, 2022
by
Khalique Ahmed
Browse files
formatting
parent
3df20646
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
6 deletions
+12
-6
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+1
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+4
-4
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+7
-1
No files found.
src/targets/gpu/device/softmax.cpp
View file @
8d21ccdf
...
@@ -22,7 +22,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -22,7 +22,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
const
index_int
max_block_size
=
120
;
const
index_int
max_block_size
=
120
;
// const index_int max_block_size = 128;
// const index_int max_block_size = 128;
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
type
init
=
lowest
();
type
init
=
lowest
();
...
...
src/targets/gpu/gemm_impl.cpp
View file @
8d21ccdf
...
@@ -83,12 +83,12 @@ void gemm_impl(context& ctx,
...
@@ -83,12 +83,12 @@ void gemm_impl(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
beta_r
=
as
(
beta
);
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
alpha_r
=
alpha
;
alpha_r
=
alpha
;
beta_r
=
beta
;
beta_r
=
beta
;
}
}
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
8d21ccdf
...
@@ -85,7 +85,13 @@ struct rocblas_gemm
...
@@ -85,7 +85,13 @@ struct rocblas_gemm
}
}
else
else
{
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
,
compute_fp32
);
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
,
compute_fp32
);
}
}
return
args
.
back
();
return
args
.
back
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment