Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5b4bb22c
"src/targets/gpu/device/sqdiff.cpp" did not exist on "96358e41cc883791c8d3ad50280bea4871a18000"
Commit
5b4bb22c
authored
Mar 01, 2019
by
Shucai Xiao
Browse files
extend the gemm implementation to support 3 arguments.
parent
5ded35b0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
138 additions
and
13 deletions
+138
-13
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+30
-4
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+9
-1
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+16
-0
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+15
-8
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+47
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+21
-0
No files found.
src/include/migraphx/operators.hpp
View file @
5b4bb22c
...
...
@@ -810,7 +810,7 @@ struct gather
struct
dot
{
float
alpha
=
1.0
;
float
beta
=
0
.0
;
float
beta
=
1
.0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -821,7 +821,7 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
();
check_shapes
{
{
inputs
[
0
],
inputs
[
1
]}
,
*
this
}.
has
(
2
).
same_type
();
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
...
...
@@ -831,14 +831,40 @@ struct dot
// as long as dim values are the same in the two inputs
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
))
{
MIGRAPHX_THROW
(
"DOT: dim values mismatch"
);
MIGRAPHX_THROW
(
"DOT: number of matrices in stack are different in A and B"
);
}
if
(
inputs
.
size
()
==
3
)
{
check_shapes
{{
inputs
[
0
],
inputs
[
2
]},
*
this
}.
has
(
2
).
same_type
();
const
shape
&
c
=
inputs
.
at
(
2
);
if
(
!
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
c
.
lens
().
rbegin
()
+
2
))
{
MIGRAPHX_THROW
(
"DOT: number of matrices in stack are different in A and C"
);
}
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
MIGRAPHX_THROW
(
"
I
nner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
MIGRAPHX_THROW
(
"
DOT : i
nner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
if
(
inputs
.
size
()
==
3
)
{
const
shape
&
c
=
inputs
.
at
(
2
);
if
(
a
.
lens
()[
dim_0
]
!=
c
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT : matrix size does not match: A: {"
+
to_string_range
(
a
.
lens
())
+
"}, C: {"
+
to_string_range
(
c
.
lens
())
+
"}"
);
}
if
(
b
.
lens
()[
dim_1
]
!=
c
.
lens
()[
dim_1
])
{
MIGRAPHX_THROW
(
"DOT : matrix size does not match: B: {"
+
to_string_range
(
b
.
lens
())
+
"}, C: {"
+
to_string_range
(
c
.
lens
())
+
"}"
);
}
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
...
...
src/targets/cpu/gemm.cpp
View file @
5b4bb22c
...
...
@@ -55,7 +55,15 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
c
=
(
a
*
b
)
*
alpha
+
beta
*
c
;
if
(
beta
!=
0.0
)
{
c
=
beta
*
c
;
}
if
(
alpha
!=
0.0
)
{
c
=
c
+
alpha
*
a
*
b
;
}
});
});
}
...
...
src/targets/cpu/lowering.cpp
View file @
5b4bb22c
...
...
@@ -374,6 +374,22 @@ struct cpu_gemm
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// if there is a C input
if
(
args
.
size
()
==
3
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
2
].
visit
([
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
}
else
{
result
.
visit
([
&
](
auto
output
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
});
}
migemm
(
result
,
args
[
0
],
args
[
1
],
op
.
alpha
,
op
.
beta
);
return
result
;
}
...
...
src/targets/gpu/gemm.cpp
View file @
5b4bb22c
...
...
@@ -90,15 +90,12 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
(
inputs
);
}
argument
miopen_gemm
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
std
::
size_t
n_dims
=
args
[
0
].
get_shape
().
lens
().
size
();
...
...
@@ -113,9 +110,19 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
batch_num
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
bool
is_3inputs
=
(
args
.
size
()
==
4
);
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
if
(
is_3inputs
)
hipMemcpy
(
to_pointer
(
args
[
3
]),
to_pointer
(
args
[
2
]),
output_shape
.
bytes
(),
hipMemcpyDeviceToDevice
);
else
hipMemset
(
to_pointer
(
args
[
2
]),
0
,
output_shape
.
bytes
());
});
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
...
...
@@ -132,14 +139,14 @@ argument miopen_gemm::compute(context& ctx,
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
ldc
,
m
*
n
,
batch_num
);
});
return
args
[
2
];
return
(
is_3inputs
?
args
[
3
]
:
args
[
2
]
)
;
}
}
// namespace gpu
...
...
test/cpu_ops_test.cpp
View file @
5b4bb22c
...
...
@@ -1112,6 +1112,53 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_3args
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
m1
=
{
1.23636469
,
-
0.47041261
,
-
0.14375651
,
-
0.48371852
,
1.16479301
,
-
0.89361055
,
-
0.18569086
,
1.10700457
,
-
1.02632638
,
0.82277012
,
0.33525769
,
0.52825145
,
-
1.00141689
,
0.45510090
,
-
0.02675039
,
-
0.60454439
,
0.38551153
,
-
0.01658514
,
0.93059292
,
-
0.54595188
,
-
0.04911005
,
-
0.91397221
,
-
0.83127477
,
-
1.57685603
,
-
1.36200452
,
2.25822236
,
-
1.23416970
,
0.12312496
,
0.76232760
,
-
0.83594234
,
1.67418145
,
-
0.19412936
,
1.05261378
,
0.66246074
,
-
1.15233398
,
0.16429736
};
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
2
,
3
}};
std
::
vector
<
float
>
m2
=
{
-
0.87300530
,
-
0.07112838
,
0.19196860
,
-
1.04986840
,
1.20348200
,
0.31966893
,
1.04805440
,
-
2.04777729
,
-
0.67906052
,
-
1.17250760
,
0.34305044
,
-
1.01957785
,
-
1.12694862
,
0.18431338
,
-
1.63712290
,
0.27566931
,
-
1.11282021
,
1.41738919
,
0.47871283
,
-
1.01980420
,
1.00212436
,
-
0.78740444
,
-
1.65636133
,
1.51466547
,
-
0.12470397
,
0.70404393
,
-
0.15244797
,
0.74288871
,
0.07339926
,
-
1.45811623
,
0.27185845
,
0.08804596
,
0.99061977
,
-
1.61752428
,
0.29191159
,
0.87271953
};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
2
}};
std
::
vector
<
float
>
m3
=
{
-
1.07692443
,
0.85223457
,
-
0.37266530
,
2.31511577
,
0.04227017
,
1.13229428
,
-
0.52769242
,
0.27307182
,
-
0.47779843
,
-
0.08023168
,
-
0.22862823
,
0.81489871
,
1.13139581
,
1.13860467
,
0.24309065
,
0.26533729
,
0.49106772
,
-
1.18860493
,
0.27842449
,
1.03568141
,
0.49759611
,
0.10021662
,
0.00592602
,
0.90862000
};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
2
,
2
}};
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
{
m1_shape
,
m1
});
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
m2_shape
,
m2
});
auto
l3
=
p
.
add_literal
(
migraphx
::
literal
{
m3_shape
,
m3
});
float
alpha
=
0.35
;
float
beta
=
0.41
;
p
.
add_instruction
(
migraphx
::
op
::
dot
{
alpha
,
beta
},
l1
,
l2
,
l3
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
m_res
=
{
-
0.91147203
,
0.47540785
,
-
0.30313587
,
0.43325099
,
-
0.43711586
,
0.50928632
,
0.06919868
,
-
0.80382802
,
-
0.05125718
,
-
0.06685650
,
-
0.06972163
,
0.32407764
,
0.45677396
,
0.25909489
,
0.56911252
,
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
maxpool_test
)
{
migraphx
::
program
p
;
...
...
test/gpu/miopen.cpp
View file @
5b4bb22c
...
...
@@ -882,6 +882,26 @@ struct gemm_mutli_dim_2_3
}
};
struct
gemm_mutli_3args
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
2
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
2
,
2
}};
auto
l1
=
p
.
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
p
.
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
p
.
add_parameter
(
"3"
,
m3_shape
);
float
alpha
=
0.35
;
float
beta
=
0.41
;
p
.
add_instruction
(
migraphx
::
op
::
dot
{
alpha
,
beta
},
l1
,
l2
,
l3
);
return
p
;
}
};
struct
test_contiguous
{
migraphx
::
program
create_program
()
const
...
...
@@ -3016,6 +3036,7 @@ int main()
verify_program
<
test_gemm_transposeab
>
();
verify_program
<
gemm_mutli_dim_2
>
();
verify_program
<
gemm_mutli_dim_2_3
>
();
verify_program
<
gemm_mutli_3args
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_eliminate_contiguous
>
();
verify_program
<
test_transpose
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment