gaoqiong / MIGraphX
Commit 3ccd7e15, authored Mar 13, 2019 by Shucai Xiao (parent: b9e0366d)

    code backup for the gemm implementation
Showing 5 changed files with 387 additions and 222 deletions:
src/include/migraphx/operators.hpp (+49 -53)
src/targets/cpu/gemm.cpp (+12 -12)
src/targets/cpu/lowering.cpp (+74 -50)
src/targets/gpu/gemm.cpp (+251 -104)
src/targets/gpu/include/migraphx/gpu/gemm.hpp (+1 -3)
src/include/migraphx/operators.hpp
@@ -819,10 +819,11 @@ struct gather
 // vector input; if A or B is 2-dim, it is a matrix (no case of a batch of
 // vectors as input). If A or B is 3 or more dims, it is considered as a
 // stack(batch) of matrices.
-// Note that, we optimze the scenario of either the Matmul or Gemm operators,
-// But for extensional scenarios like GEMM with three inputs, and each arg
-// is a batch is matrices, the implementation may need further optimization
-// later.
+// Note that we only support the scenario of either the Matmul or Gemm
+// operators. That is, if there are 3 inputs, we consider it is a Gemm, then
+// A and B must be matrix inputs, and C is broadcastable to A * B. If there
+// is only two inputs, A and B can be 1-dim to N-dim, in this case, there
+// is no C input.
 struct dot
 {
     float alpha = 1.0;
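The rewritten comment pins down the two call patterns the operator supports: a 3-input Gemm (A and B matrices, C broadcastable to A * B) and a 2-input numpy.matmul-style dot. As a rough, self-contained sketch of the 2-input shape rule for equal-rank batched inputs (the function and shapes below are illustrative, not part of this commit; 1-dim inputs and broadcast batch dims are omitted):

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Output dims of dot(A, B) when both inputs are >= 2-dim and equal rank:
// the last two dims multiply as matrices, leading dims are the batch.
std::vector<std::size_t> matmul_dims(const std::vector<std::size_t>& a,
                                     const std::vector<std::size_t>& b)
{
    assert(a.size() >= 2 && a.size() == b.size());
    assert(a.back() == b[b.size() - 2]); // inner dims must agree
    std::vector<std::size_t> out(a.begin(), a.end() - 1);
    out.push_back(b.back());
    return out;
}

int main()
{
    // a stack of 3 matrices: {3, 2, 4} x {3, 4, 5} -> {3, 2, 5}
    for(auto d : matmul_dims({3, 2, 4}, {3, 4, 5}))
        std::cout << d << ' ';
    std::cout << '\n';
}
```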
@@ -844,25 +845,12 @@ struct dot
         if(a.empty())
         {
-            if(is_mutli_broadcast)
-            {
-                return b;
-            }
-            else
-            {
-                MIGRAPHX_THROW("DOT: C is not broadcastable to A * B (scalar)");
-            }
+            return b;
         }
         auto a_size = a.size();
         auto b_size = b.size();
-        if(is_mutli_broadcast && b_size > a_size)
-        {
-            MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
-                           "} is not broadcastable to A * b {" + to_string_range(a) + "}");
-        }
         auto n_dim = std::min(a_size, b_size);
         std::vector<std::size_t> out_lens(std::max(a_size, b_size));
         for(std::size_t i = 0; i < n_dim; ++i)
@@ -875,27 +863,15 @@ struct dot
             {
                 out_lens[i] = a[a_size - 1 - i];
             }
+            else if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
+            {
+                out_lens[i] = b[b_size - 1 - i];
+            }
             else
             {
-                if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
-                {
-                    out_lens[i] = b[b_size - 1 - i];
-                }
-                else
-                {
-                    if(is_mutli_broadcast)
-                    {
-                        MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
-                                       to_string_range(a) + "}, and matrix B: {" +
-                                       to_string_range(b) + "} are not broadcastable");
-                    }
-                    else
-                    {
-                        MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
-                                       "} is not broadcastable to A * b {" +
-                                       to_string_range(a) + "}");
-                    }
-                }
+                MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
+                               to_string_range(a) + "}, and matrix B: {" +
+                               to_string_range(b) + "} are not broadcastable");
             }
         }
@@ -917,7 +893,42 @@ struct dot
     std::string name() const { return "dot"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{{inputs[0], inputs[1]}, *this}.has(2).same_type();
+        // If there are 3 inputs, then A and B must be matrices and
+        // C is broadcastable to A * B
+        if(inputs.size() == 3)
+        {
+            check_shapes{inputs, *this}.has(3).same_type();
+            check_shapes{{inputs[0]}, *this}.only_dims(2);
+            check_shapes{{inputs[1]}, *this}.only_dims(2);
+            auto a_lens = inputs[0].lens();
+            auto b_lens = inputs[1].lens();
+            auto t      = inputs[0].type();
+            if(a_lens[1] != b_lens[0])
+            {
+                MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" +
+                               to_string_range(a_lens) + "}, cannot multiply operand B: {" +
+                               to_string_range(b_lens) + "}");
+            }
+            auto out_lens = a_lens;
+            out_lens[0]   = b_lens[0];
+            // check whether C is broadcastable to A * B
+            auto c_lens = inputs[2].lens();
+            if(c_lens.size() > 2 ||
+               (c_lens.size() >= 1 && (c_lens[0] != 1 && c_lens[0] != b_lens[0])) ||
+               (c_lens.size() == 2 && (c_lens[1] != 1 && c_lens[1] != a_lens[1])))
+            {
+                MIGRAPHX_THROW("DOT: C {" + to_string_range(c_lens) +
+                               "} is not broadcastable to A * B {" +
+                               to_string_range(out_lens) + "}");
+            }
+            return {t, out_lens};
+        }
+        // For the case of two inputs, it is the numpy.matmul
+        check_shapes{inputs, *this}.has(2).same_type();
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();
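The added 3-input branch encodes a specific broadcast test for C. Extracted as a standalone predicate it reads as below; this is an illustrative paraphrase rather than code from the commit, and it assumes out_lens = {b_lens[0], a_lens[1]} exactly as the branch computes it:

```cpp
#include <cstddef>
#include <vector>

// True when C (at most 2-dim) can broadcast to the GEMM output: each dim
// of C must be 1 or equal to the corresponding output dim.
bool c_broadcastable(const std::vector<std::size_t>& c_lens,
                     const std::vector<std::size_t>& out_lens)
{
    if(c_lens.size() > 2)
        return false;
    if(!c_lens.empty() && c_lens[0] != 1 && c_lens[0] != out_lens[0])
        return false;
    if(c_lens.size() == 2 && c_lens[1] != 1 && c_lens[1] != out_lens[1])
        return false;
    return true;
}
```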
@@ -977,21 +988,6 @@ struct dot
             out_lens.pop_back();
         }
-        // c is unibroadcastable to A * B
-        if(inputs.size() == 3)
-        {
-            // same type as A and B
-            check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
-            if(out_lens.empty() && (!inputs[2].scalar()))
-            {
-                MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
-            }
-            // check c is broadcastable to A * B
-            auto c_lens = inputs[2].lens();
-            shape_broadcast(out_lens, c_lens, false);
-        }
         if(out_lens.empty())
         {
             return {t};
src/targets/cpu/gemm.cpp
@@ -77,19 +77,19 @@ void migemm_impl(tensor_view<T> cmat,
     auto b_lens = bmat.get_shape().lens();
     auto c_lens = cmat.get_shape().lens();
-    std::size_t n_dims = c_lens.size();
-    std::size_t dim_0  = n_dims - 2;
-    std::size_t dim_1  = n_dims - 1;
-    auto k             = a_lens[dim_1];
+    std::size_t nc_dims = c_lens.size();
+    std::size_t na_dims = a_lens.size();
+    std::size_t nb_dims = b_lens.size();
+    auto k              = a_lens[na_dims - 1];
-    assert(a_lens[dim_1] == b_lens[dim_0]);
-    assert(c_lens[dim_0] == a_lens[dim_0]);
-    assert(c_lens[dim_1] == b_lens[dim_1]);
+    assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]);
+    assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]);
+    assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]);
-    std::size_t a_len_diff = c_lens.size() - a_lens.size();
-    std::size_t b_len_diff = c_lens.size() - b_lens.size();
-    std::vector<std::size_t> a_idx(a_lens.size());
-    std::vector<std::size_t> b_idx(b_lens.size());
+    std::size_t a_len_diff = nc_dims - na_dims;
+    std::size_t b_len_diff = nc_dims - nb_dims;
+    std::vector<std::size_t> a_idx(na_dims);
+    std::vector<std::size_t> b_idx(nb_dims);
     shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
         std::transform(c_lens.begin() + a_len_diff,
@@ -105,7 +105,7 @@ void migemm_impl(tensor_view<T> cmat,
         double s = 0.0;
         dfor(k)([&](auto kk) {
-            a_idx[dim_1] = b_idx[dim_0] = kk;
+            a_idx[na_dims - 1] = b_idx[nb_dims - 2] = kk;
             s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
         });
         cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
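The switch from a single n_dims to per-tensor ranks (na_dims, nb_dims, nc_dims) is what lets A, B, and C carry different numbers of batch dims: each output index is shortened by the rank difference before it indexes A or B, and the contraction dim is then overwritten inside the dfor loop. A minimal sketch of that mapping (an illustrative helper, not from the commit):

```cpp
#include <cstddef>
#include <vector>

// Map an output (C) index to the index of a lower-rank input by dropping
// the leading batch dims the input does not have; the caller afterwards
// overwrites the contraction dim (a_idx[na_dims - 1] / b_idx[nb_dims - 2]).
std::vector<std::size_t> map_index(const std::vector<std::size_t>& c_idx,
                                   std::size_t in_rank)
{
    std::size_t diff = c_idx.size() - in_rank; // == a_len_diff / b_len_diff
    return {c_idx.begin() + diff, c_idx.end()};
}
// e.g. a C index {b, i, j} against a 2-dim A yields the A index {i, j}
```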
src/targets/cpu/lowering.cpp
@@ -371,14 +371,82 @@ struct cpu_gemm
     std::string name() const { return "cpu::dot"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
 
+    void fill_result(argument& result, argument& c) const
+    {
+        auto out_lens = result.get_shape().lens();
+        auto c_lens   = c.get_shape().lens();
+        if(out_lens == c_lens)
+        {
+            visit_all(result, c)([&](auto output, auto input) {
+                std::memcpy(output.data(), input.data(), c_shape.bytes());
+            });
+        }
+        // need broadcast
+        else if(c.single())
+        {
+            visit_all(result, c)([&](auto output, auto input) {
+                std::fill(output.begin(), output.end(), input.front());
+            });
+        }
+        // must be c_lens[0] == output_lens[1]
+        else if(c_lens.size() == 1 || (c_lens.size() == 2 && (c_lens[1] == out_lens[1])))
+        {
+            std::size_t m = out_lens[0];
+            std::size_t n = out_lens[1];
+            visit_all(result, c)([&](auto output, auto input) {
+                for(std::size_t i = 0; i < m; i++)
+                {
+                    std::memcpy((output.data() + i * n), input.data(), c_shape.bytes());
+                }
+            });
+        }
+        // c_lens.size() == 2 and c_lens[0] == out_lens[0]
+        else
+        {
+            std::size_t m = out_lens[0];
+            std::size_t n = out_lens[1];
+            visit_all(result, c)([&](auto output, auto input) {
+                for(std::size_t i = 0; i < m; i++)
+                {
+                    std::fill(output.begin() + i * n,
+                              (i + 1 == m) ? output.end() : output.begin() + ((i + 1) * n),
+                              input[i]);
+                }
+            });
+        }
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        // 3 inputs, it is alpha * A * B + beta * C, then
+        // A and B are matrics, and C is broadcastable to A * B
+        if(args.size() == 3)
+        {
+            // no need to consider the value of args[2]
+            if(op.beta == 0.0f)
+            {
+                result.visit([&](auto output) {
+                    std::memset(output.data(), 0, output_shape.bytes());
+                });
+            }
+            else
+            {
+                fill_result(result, args[2]);
+            }
+            migemm(result, args[0], args[1], op.alpha, op.beta);
+            return result;
+        }
+        // 2 input cases
         // all args are scalar
         if(output_shape.scalar())
         {
-            visit_all(result, args[0], args[1], args[2])([&](auto ret, auto a, auto b, auto c) {
-                ret[0] = op.alpha * a[0] * b[0] + op.beta * c[0];
-            });
+            visit_all(result, args[0], args[1])([&](auto res, auto a, auto b) {
+                res[0] = op.alpha * a[0] * b[0];
+            });
             return result;
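The new fill_result pre-loads C into the output buffer so that migemm can then accumulate alpha * A * B on top of it with the given beta. The row-broadcast case, where a length-n C is copied into every row of an m x n output, behaves like this small standalone demo (illustrative, not code from the commit):

```cpp
#include <cstring>
#include <iostream>
#include <vector>

int main()
{
    std::size_t m = 2, n = 3;
    std::vector<float> c = {1.f, 2.f, 3.f};
    std::vector<float> result(m * n);
    // copy C into each output row, as the memcpy loop in fill_result does
    for(std::size_t i = 0; i < m; i++)
        std::memcpy(result.data() + i * n, c.data(), n * sizeof(float));
    for(auto v : result)
        std::cout << v << ' '; // prints: 1 2 3 1 2 3
    std::cout << '\n';
}
```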
@@ -406,55 +474,11 @@ struct cpu_gemm
             out_lens.push_back(1);
         }
-        // if there is a C input
-        if(args.size() == 2)
-        {
-            result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
-            migemm({{t, out_lens}, result.data()},
-                   {{t, a_lens}, args[0].data()},
-                   {{t, b_lens}, args[1].data()},
-                   op.alpha,
-                   op.beta);
-            return result;
-        }
-        // 3 input arguments
-        auto c_shape = args[2].get_shape();
-        // In GEMM, C is broadcastable to A * B, so we should consider C
-        // is not the same shape as A * B. If the same shape, copy C to
-        // the memory of the output
-        if(c_shape == output_shape)
-        {
-            // memory copy is more efficient than doing element by element
-            result.visit([&](auto output) {
-                args[2].visit([&](auto input) {
-                    std::memcpy(output.data(), input.data(), c_shape.bytes());
-                });
-            });
-        }
-        else
-        {
-            auto out_len = output_shape.lens();
-            auto c_lens  = c_shape.lens();
-            std::size_t len_diff = out_len.size() - c_lens.size();
-            visit_all(result, args[2])([&](auto output, auto c) {
-                shape_for_each(output_shape, [&](auto out_idx) {
-                    // compute the input index
-                    std::vector<std::size_t> in_idx(c_lens.size());
-                    std::transform(c_lens.begin(),
-                                   c_lens.end(),
-                                   out_len.begin() + len_diff,
-                                   in_idx.begin(),
-                                   [&](auto i, auto j) { return (i == 1) ? 0 : j; });
-                    output(out_idx.begin(), out_idx.end()) = c(in_idx.begin(), in_idx.end());
-                });
-            });
-        }
         migemm({{t, out_lens}, result.data()},
                {{t, a_lens}, args[0].data()},
                {{t, b_lens}, args[1].data()},
                op.alpha,
-               op.beta);
+               0.0f);
         return result;
     }
src/targets/gpu/gemm.cpp
@@ -171,10 +171,75 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
     return op.compute_shape(inputs);
 }
 
-std::size_t miopen_gemm::compute_offset(std::vector<std::size_t>& out_lens,
-                                        std::size_t index,
-                                        std::vector<std::size_t>& data_lens) const
+void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
+                              const argument& result, const argument& c) const
 {
+    auto out_lens = output_shape.lens();
+    auto c_lens   = c.get_shape().lens();
+    if(output_shape == c.get_shape())
+    {
+        output_shape.visit_type([&](auto as) {
+            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+            hipMemcpy(to_pointer(args[3]),
+                      to_pointer(args[2]),
+                      output_shape.bytes(),
+                      hipMemcpyDeviceToDevice);
+        });
+    }
+    else if(c.single())
+    {
+        output_shape.visit_type([&](auto as) {
+            auto to_pointer = [&](auto&& arg, std::size_t offset) {
+                return to_rocblas_type(as.from(arg.data() + offset));
+            };
+            for(std::size_t i = 0; i < output_shape.elements(); ++i)
+            {
+                hipMemcpy(to_pointer(args[3], i),
+                          to_pointer(args[2]),
+                          args[2].get_shape().bytes(),
+                          hipMemcpyDeviceToDevice);
+            }
+        });
+    }
+    else if(c_lens.size() == 1 || (c_lens.size() == 2 && c_lens[1] == out_lens[1]))
+    {
+        auto m = out_lens[0];
+        auto n = out_lens[1];
+        output_shape.visit_type([&](auto as) {
+            auto to_pointer = [&](auto&& arg, std::size_t offset) {
+                return to_rocblas_type(as.from(arg.data() + offset));
+            };
+            for(std::size_t i = 0; i < m; ++i)
+            {
+                hipMemcpy(to_pointer(args[3], i * n),
+                          to_pointer(args[2]),
+                          args[2].get_shape().bytes(),
+                          hipMemcpyDeviceToDevice);
+            }
+        });
+    }
+    // case of c_lens.size() == 2 && c_len[0] == out_lens[0]
+    else
+    {
+        output_shape.visit_type([&](auto as) {
+            auto to_pointer = [&](auto&& arg, std::size_t offset) {
+                return to_rocblas_type(as.from(arg.data() + offset));
+            };
+            for(std::size_t i = 0; i < output_shape.elements(); ++i)
+            {
+                hipMemcpy(to_pointer(args[3], i),
+                          to_pointer(args[2], i / n),
+                          args[2].get_shape().type_size(),
+                          hipMemcpyDeviceToDevice);
+            }
+        });
+    }
 }
 
 argument miopen_gemm::compute(context& ctx,
@@ -182,12 +247,51 @@ argument miopen_gemm::compute(context& ctx,
                               const std::vector<argument>& args) const
 {
     bool is_3inputs = (args.size() == 4);
+    if(is_3inputs)
+    {
+        fill_result(ctx, output_shape, args[3], args[2]);
+        output_shape.visit_type([&](auto as) {
+            auto alpha_r = to_rocblas_type(as(op.alpha));
+            auto beta_r  = to_rocblas_type(as(op.beta));
+            bool transa  = args[0].get_shape().transposed();
+            bool transb  = args[1].get_shape().transposed();
+            rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
+            rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
+            rocblas_int ldc = args[2].get_shape().strides()[0];
+            auto out_lens   = output_shape.lens();
+            rocblas_int m   = out_lens[0];
+            rocblas_int n   = out_lens[1];
+            rocblas_int k   = args[0].get_shape().lens()[1];
+            auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+            generic_rocblas_gemm(as,
+                                 ctx.get_stream().get_rocblas(),
+                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
+                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
+                                 n,
+                                 m,
+                                 k,
+                                 &alpha_r,
+                                 to_pointer(args[1]),
+                                 ldb,
+                                 to_pointer(args[0]),
+                                 lda,
+                                 &beta_r,
+                                 to_pointer(args[2]),
+                                 ldc);
+        });
+        return args[3];
+    }
+
+    // 2 input arguments cases
+    // vector inner product
     if(output_shape.elements() == 1)
     {
+        assert(args[0].get_shape().elements() == args[1].get_shape().elements());
         output_shape.visit_type([&](auto as) {
             auto alpha_r = to_rocblas_type(as(op.alpha));
+            auto beta_r  = to_rocblas_type(as(op.beta));
             auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
             generic_rocblas_dot(as,
                                 ctx.get_stream().get_rocblas(),
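A note on the argument order in the GEMM call above: rocBLAS assumes column-major storage, while MIGraphX buffers are row-major. A row-major m x n buffer reinterpreted as column-major is the n x m transpose, so the call computes C^T = B^T * A^T by passing B before A and n before m; the result then lands in C already row-major. A plain-loop sketch of the same trick (illustrative, not code from the commit):

```cpp
#include <cstddef>

// C (row-major, m x n) = A (row-major, m x k) * B (row-major, k x n),
// written the way a column-major BLAS sees it: the (n x m) product B' * A'.
void gemm_rowmajor_via_colmajor(const float* a, const float* b, float* c,
                                std::size_t m, std::size_t n, std::size_t k)
{
    for(std::size_t j = 0; j < m; ++j)     // columns of C^T
        for(std::size_t i = 0; i < n; ++i) // rows of C^T
        {
            float s = 0.f;
            for(std::size_t kk = 0; kk < k; ++kk)
                s += b[kk * n + i] * a[j * k + kk];
            c[j * n + i] = s;
        }
}
```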
@@ -196,129 +300,172 @@ argument miopen_gemm::compute(context& ctx,
                                 1,
                                 to_pointer(args[1]),
                                 1,
-                                to_pointer(args[2]));
+                                is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]));
             generic_rocblas_scal(as,
                                  ctx.get_stream().get_rocblas(),
                                  1,
                                  &alpha_r,
-                                 to_pointer(args[2]),
+                                 is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                                  1);
+            if(is_3inputs)
+            {
+                generic_rocblas_axpy(as,
+                                     ctx.get_stream().get_rocblas(),
+                                     1,
+                                     &beta_r,
+                                     to_pointer(args[2]),
+                                     1,
+                                     to_pointer(args[3]),
+                                     1);
+            }
         });
+        return is_3inputs ? args[3] : args[2];
     }
-    // matrix * vector
-    else if(args[1].get_shape().lens().size() == 1)
-    {
-        auto a_lens       = args[0].get_shape().lens();
-        std::size_t dim_0 = a_lens.size() - 2;
-        std::size_t dim_1 = a_lens.size() - 1;
-        bool transa       = args[0].get_shape().transposed();
-        rocblas_int lda   = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-        rocblas_int m     = a_lens[dim_0];
-        rocblas_int k     = a_lens[dim_1];
-        auto batch_num    = std::accumulate(
-            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-        output_shape.visit_type([&](auto as) {
-            auto alpha_r    = to_rocblas_type(as(op.alpha));
-            auto beta_r     = to_rocblas_type(as(op.beta));
-            auto to_pointer = [&](auto&& arg, std::size_t offset) {
-                return to_rocblas_type(as.from(arg.data() + offset));
-            };
-            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
-            {
-                if(is_3inputs)
-                    hipMemcpy(to_pointer(args[3] + batch_no * m),
-                              to_pointer(args[2]),
-                              output_shape.bytes(),
-                              hipMemcpyDeviceToDevice);
-                else
-                    hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
-            }
-        });
-    }
+    // b is a vector, so the computation is matrix * vector
+    // could not be the case of inner product of vectors since
+    // it is already processed above
+    if(args[1].get_shape().lens().size() == 1)
+    {
+        // considering the batch input, so A could be a batch
+        // of matrices
+        auto a_lens        = args[0].get_shape().lens();
+        std::size_t n_dims = a_lens.size();
+        bool trans         = args[0].get_shape().transposed();
+        std::size_t dim_0  = n_dims - 2;
+        std::size_t dim_1  = n_dims - 1;
+        rocblas_int m      = a_lens[trans ? dim_1 : dim_0];
+        rocblas_int n      = a_lens[trans ? dim_0 : dim_1];
+        float beta         = 0.0f;
+        rocblas_int lda    = args[0].get_shape().strides()[trans ? dim_1 : dim_0];
+        assert(a_lens.back() == args[1].get_shape().elements());
+        std::size_t batch_num = std::accumulate(
+            a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
+        output_shape.visit_type([&](auto as) {
+            auto alpha_r    = to_rocblas_type(as(op.alpha));
+            auto beta_r     = to_rocblas_type(as(beta));
+            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
+                return to_rocblas_type(as.from(arg.data() + offset));
+            };
+            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
+            {
+                generic_rocblas_gemv(as,
+                                     ctx.get_stream().get_rocblas(),
+                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
+                                     m,
+                                     n,
+                                     &alpha_r,
+                                     to_pointer(args[0], batch_no * m * n),
+                                     lda,
+                                     to_pointer(args[1]),
+                                     1,
+                                     &beta_r,
+                                     to_pointer(args[2], batch_no * n),
+                                     1);
+            }
+        });
+    }
+    // vector * matrix
+    else if(args[0].get_shape().lens().size() == 1)
+    {
+        auto b_lens        = args[1].get_shape().lens();
+        std::size_t dim_0  = b_lens.size() - 2;
+        std::size_t dim_1  = b_lens.size() - 1;
+        bool trans         = !args[1].get_shape().transposed();
+        rocblas_int m      = b_lens[trans ? dim_1 : dim_0];
+        rocblas_int n      = b_lens[trans ? dim_0 : dim_1];
+        float beta         = 0.0f;
+        rocblas_int lda    = args[1].get_shape().strides()[trans ? dim_1 : dim_0];
+        assert(b_lens.back() == args[0].get_shape().elements());
+        std::size_t batch_num = std::accumulate(
+            b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
+        output_shape.visit_type([&](auto as) {
+            auto alpha_r    = to_rocblas_type(as(op.alpha));
+            auto beta_r     = to_rocblas_type(as(beta));
+            auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
+                return to_rocblas_type(as.from(arg.data() + offset));
+            };
+            for(std::size_t batch_no = 0; batch_no < batch_num; ++batch_no)
+            {
+                generic_rocblas_gemv(as,
+                                     ctx.get_stream().get_rocblas(),
+                                     trans ? rocblas_operation_transpose : rocblas_operation_none,
+                                     m,
+                                     n,
+                                     &alpha_r,
+                                     to_pointer(args[1], batch_no * m * n),
+                                     lda,
+                                     to_pointer(args[0]),
+                                     1,
+                                     &beta_r,
+                                     to_pointer(args[2], batch_no * m),
+                                     1);
+            }
+        });
+    }
-    bool transa        = args[0].get_shape().transposed();
-    bool transb        = args[1].get_shape().transposed();
-    std::size_t n_dims = args[0].get_shape().lens().size();
-    std::size_t dim_0  = n_dims - 2;
-    std::size_t dim_1  = n_dims - 1;
-    rocblas_int lda    = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-    rocblas_int ldb    = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-    rocblas_int ldc    = args[2].get_shape().strides()[dim_0];
-    auto out_lens      = output_shape.lens();
-    rocblas_int m      = out_lens[dim_0];
-    rocblas_int n      = out_lens[dim_1];
-    rocblas_int k      = args[0].get_shape().lens()[dim_1];
-    auto batch_num     = std::accumulate(
-        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-    // two input arguments
-    if(!is_3inputs)
-    {
-    }
-    output_shape.visit_type([&](auto as) {
-        auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-            return to_rocblas_type(as.from(arg.data() + offset));
-        };
-        if(is_3inputs)
-            hipMemcpy(to_pointer(args[3]),
-                      to_pointer(args[2]),
-                      output_shape.bytes(),
-                      hipMemcpyDeviceToDevice);
-        else
-            hipMemset(to_pointer(args[2]), 0, output_shape.bytes());
-    });
-    output_shape.visit_type([&](auto as) {
-        auto alpha_r    = to_rocblas_type(as(op.alpha));
-        auto beta_r     = to_rocblas_type(as(op.beta));
-        auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
-            return to_rocblas_type(as.from(arg.data() + offset));
-        };
-        generic_rocblas_batched_gemm(as,
-                                     ctx.get_stream().get_rocblas(),
-                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
-                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
-                                     n,
-                                     m,
-                                     k,
-                                     &alpha_r,
-                                     to_pointer(args[1]),
-                                     ldb,
-                                     k * n,
-                                     to_pointer(args[0]),
-                                     lda,
-                                     m * k,
-                                     &beta_r,
-                                     to_pointer(args[2]),
-                                     ldc,
-                                     m * n,
-                                     batch_num);
-    });
-    return args[2];
+    // (batch) matrix multiplication
+    else
+    {
+        bool transa     = args[0].get_shape().transposed();
+        bool transb     = args[1].get_shape().transposed();
+        auto a_lens     = args[0].get_shape().lens();
+        auto b_lens     = args[1].get_shape().lens();
+        auto out_lens   = output_shape.lens();
+        rocblas_int lda =
+            args[0].get_shape().strides()[transa ? a_lens.size() - 1 : a_lens.size() - 2];
+        rocblas_int ldb =
+            args[1].get_shape().strides()[transb ? b_lens.size() - 1 : b_lens.size() - 2];
+        rocblas_int ldc = args[2].get_shape().strides()[out_lens.size() - 2];
+        rocblas_int m   = out_lens[out_lens.size() - 2];
+        rocblas_int n   = out_lens[out_lens.size() - 1];
+        rocblas_int k   = args[0].get_shape().lens()[a_lens.size() - 1];
+        auto input_dims = std::min(a_lens.size(), b_lens.size());
+        std::size_t axis{0};
+        for(axis = 2; axis < input_dims; ++axis)
+        {
+            if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
+            {
+                break;
+            }
+        }
+        // The number of matrices that can be computed in one call
+        // batch_num > 1, we need to call the batch_gemm function,
+        // otherwise, call the gemm function directly
+        std::size_t num_matrices = std::accumulate(
+            a_lens.rbegin() + 2,
+            (axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis),
+            std::size_t{1},
+            std::multiplies<std::size_t>());
+        std::size_t a_len_diff = out_lens.size() - a_lens.size();
+        std::size_t b_len_diff = out_lens.size() - b_lens.size();
+        std::vector<std::size_t> a_batch_lens(a_lens.begin(),
+                                              a_lens.begin() + a_lens.size() - axis);
+        std::vector<std::size_t> b_batch_lens(b_lens.begin(),
+                                              b_lens.begin() + b_lens.size() - axis);
+        std::vector<std::size_t> out_batch_lens(out_lens.begin(),
+                                                out_lens.begin() + out_lens.size() - axis);
+        shape::type_t t = output_shape.type();
+        shape a_batch_shape{t, a_batch_lens};
+        shape b_batch_shape{t, b_batch_lens};
+        shape out_diff_shape{t, out_batch_lens};
+        shape_for_each(out_diff_shape, [&](auto out_idx) {
+            std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
+            std::vector<std::size_t> a_idx(a_lens.size() - axis);
+            std::vector<std::size_t> b_idx(b_lens.size() - axis);
+            std::transform(out_idx.begin() + a_len_diff,
+                           out_idx.end(),
+                           a_batch_lens.begin(),
+                           a_idx.begin(),
+                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+            std::transform(out_idx.begin() + b_len_diff,
+                           out_idx.end(),
+                           b_batch_lens.begin(),
+                           b_idx.begin(),
+                           [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+            std::size_t a_ind = a_batch_shape.index(a_idx.begin(), b_idx.end());
+            std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
+            output_shape.visit_type([&](auto as) {
+                auto alpha_r    = to_rocblas_type(as(op.alpha));
+                auto beta_r     = to_rocblas_type(as(beta));
+                auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
+                    return to_rocblas_type(as.from(arg.data() + offset));
+                };
+                generic_rocblas_batched_gemm(
+                    as,
+                    ctx.get_stream().get_rocblas(),
+                    transb ? rocblas_operation_transpose : rocblas_operation_none,
+                    transa ? rocblas_operation_transpose : rocblas_operation_none,
+                    n,
+                    m,
+                    k,
+                    &alpha_r,
+                    to_pointer(args[1], k * n * num_matrices * b_ind),
+                    ldb,
+                    k * n,
+                    to_pointer(args[0], m * k * num_matrices * a_ind),
+                    lda,
+                    m * k,
+                    &beta_r,
+                    to_pointer(args[2], m * n * num_matrices * out_ind),
+                    ldc,
+                    m * n,
+                    num_matrices);
+            });
+        });
+    }
+    return (is_3inputs ? args[3] : args[2]);
 }
 } // namespace gpu
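The new (batch) matrix-multiplication path groups as many trailing batch dims as A and B share into a single strided batched-GEMM call (num_matrices), and loops shape_for_each over whatever batch dims remain, since those may broadcast. The grouping logic, pulled out as a standalone helper (an illustrative paraphrase, not code from the commit):

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Count matching trailing dims of A and B (starting after the two matrix
// dims) and multiply them into the number of matrices that one strided
// batched-GEMM call can cover.
std::size_t matrices_per_call(const std::vector<std::size_t>& a_lens,
                              const std::vector<std::size_t>& b_lens)
{
    std::size_t axis       = 2;
    std::size_t input_dims = std::min(a_lens.size(), b_lens.size());
    for(; axis < input_dims; ++axis)
    {
        if(a_lens[a_lens.size() - axis] != b_lens[b_lens.size() - axis])
            break;
    }
    return std::accumulate(a_lens.rbegin() + 2,
                           axis == a_lens.size() ? a_lens.rend() : a_lens.rbegin() + axis,
                           std::size_t{1},
                           std::multiplies<std::size_t>());
}
```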
src/targets/gpu/include/migraphx/gpu/gemm.hpp
@@ -20,9 +20,7 @@ struct miopen_gemm
     int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
 
     private:
-    std::size_t compute_offset(std::vector<std::size_t>& out_lens,
-                               std::size_t index,
-                               std::vector<std::size_t>& data_lens) const;
+    void fill_result(context& ctx,
+                     const shape& output_shape,
+                     const argument& result, const argument& c) const;
 };
 } // namespace gpu
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment