Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2899f9a5
Commit
2899f9a5
authored
Feb 22, 2019
by
Shucai Xiao
Browse files
enhance gather to support scalar as input indices.
parent
c50e8004
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
36 deletions
+154
-36
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+59
-24
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+37
-1
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+58
-11
No files found.
src/include/migraphx/operators.hpp
View file @
2899f9a5
...
@@ -758,27 +758,38 @@ struct gather
...
@@ -758,27 +758,38 @@ struct gather
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
int
axis_index
=
(
axis
<
0
)
?
(
n_dim
+
axis
)
:
axis
;
auto
type
=
inputs
[
0
].
type
();
auto
type
=
inputs
[
0
].
type
();
lens
[
axis_index
]
=
inputs
[
1
].
elements
();
lens
.
erase
(
lens
.
begin
()
+
axis_index
);
if
(
!
inputs
[
1
].
scalar
())
return
{
type
,
lens
};
{
auto
ind_lens
=
inputs
[
1
].
lens
();
lens
.
insert
(
lens
.
begin
()
+
axis_index
,
ind_lens
.
begin
(),
ind_lens
.
end
());
}
}
template
<
class
T
>
// for scalar output
void
compute_index
(
const
T
&
out_idx
,
if
(
lens
.
size
()
==
0
)
const
int
axis_index
,
const
std
::
vector
<
std
::
size_t
>&
vec_indices
,
const
std
::
size_t
max_dim
,
T
&
in_idx
)
const
{
in_idx
=
out_idx
;
std
::
size_t
idx
=
vec_indices
.
at
(
out_idx
[
axis_index
]);
if
(
idx
>=
max_dim
)
{
{
MIGRAPHX_THROW
(
"Gather: indices are out of range in input tensor"
)
;
return
{
type
,
{
1
},
{
0
}}
;
}
}
in_idx
[
axis_index
]
=
idx
;
return
{
type
,
lens
};
}
}
// template <class T>
// void compute_index(const T& out_idx,
// const int axis_index,
// const std::vector<std::size_t>& vec_indices,
// const std::size_t max_dim,
// T& in_idx) const
// {
// in_idx = out_idx;
// std::size_t idx = vec_indices.at(out_idx[axis_index]);
// if(idx >= max_dim)
// {
// MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
// }
// in_idx[axis_index] = idx;
// }
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
...
@@ -786,14 +797,38 @@ struct gather
...
@@ -786,14 +797,38 @@ struct gather
int
axis_index
=
(
axis
<
0
)
?
(
output_shape
.
lens
().
size
()
+
axis
)
:
axis
;
int
axis_index
=
(
axis
<
0
)
?
(
output_shape
.
lens
().
size
()
+
axis
)
:
axis
;
// max dimension in axis
// max dimension in axis
std
::
size_t
max_dim
=
args
[
0
].
get_shape
().
lens
()[
axis_index
];
//std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std
::
vector
<
std
::
size_t
>
vec_indices
;
//std::vector<std::size_t> vec_indices;
args
[
1
].
visit
([
&
](
auto
indices
)
{
vec_indices
.
assign
(
indices
.
begin
(),
indices
.
end
());
});
//args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
std
::
vector
<
std
::
size_t
>
in_idx
;
args
[
1
].
visit
([
&
]
(
auto
indices
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
if
(
indices
.
get_shape
().
scalar
())
this
->
compute_index
(
idx
,
axis_index
,
vec_indices
,
max_dim
,
in_idx
);
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
if
(
output_shape
.
scalar
())
{
output
[
0
]
=
data
[
indices
.
front
()];
}
else
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
data_idx
.
insert
(
data_idx
.
begin
()
+
axis_index
,
indices
.
front
());
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
}
}
else
{
auto
ind_lens
=
indices
.
get_shape
().
lens
();
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
out_idx
)
{
auto
data_idx
=
out_idx
;
auto
start_it
=
data_idx
.
begin
()
+
axis_index
;
auto
end_it
=
data_idx
.
begin
()
+
axis_index
+
ind_lens
.
size
();
std
::
vector
<
std
::
size_t
>
ind_idx
(
start_it
,
end_it
);
data_idx
.
erase
(
start_it
,
end_it
);
data_idx
.
insert
(
start_it
,
indices
(
ind_idx
.
begin
(),
ind_idx
.
end
()));
output
(
out_idx
.
begin
(),
out_idx
.
end
())
=
data
(
data_idx
.
begin
(),
data_idx
.
end
());
});
}
});
});
});
});
...
...
src/onnx/onnx.cpp
View file @
2899f9a5
...
@@ -434,6 +434,14 @@ struct onnx_parser
...
@@ -434,6 +434,14 @@ struct onnx_parser
const
std
::
vector
<
instruction_ref
>&
)
const
std
::
vector
<
instruction_ref
>&
)
{
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
migraphx
::
shape
v_shape
=
v
.
get_shape
();
// for constant containing 1 element, consider it as a scalar
if
(
v_shape
.
elements
()
==
1
)
{
migraphx
::
shape
scalar_shape
{
v_shape
.
type
(),
{
1
},
{
0
}};
return
prog
.
add_literal
(
migraphx
::
literal
{
scalar_shape
,
v
.
data
()});
}
return
prog
.
add_literal
(
v
);
return
prog
.
add_literal
(
v
);
}
}
...
@@ -460,6 +468,18 @@ struct onnx_parser
...
@@ -460,6 +468,18 @@ struct onnx_parser
{
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
}
// beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results
std
::
size_t
num_squeeze
=
args
[
0
]
->
get_shape
().
lens
().
size
();
if
(
num_squeeze
>
2
)
{
std
::
vector
<
int64_t
>
vec_axises
(
num_squeeze
-
2
);
std
::
iota
(
vec_axises
.
begin
(),
vec_axises
.
end
(),
0
);
args
[
0
]
=
prog
.
add_instruction
(
op
::
squeeze
{
vec_axises
},
args
[
0
]);
args
[
1
]
=
prog
.
add_instruction
(
op
::
squeeze
{
vec_axises
},
args
[
1
]);
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
op
::
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
...
@@ -468,6 +488,13 @@ struct onnx_parser
...
@@ -468,6 +488,13 @@ struct onnx_parser
if
(
beta
!=
0.
f
)
if
(
beta
!=
0.
f
)
{
{
auto
l3
=
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
auto
l3
=
prog
.
add_instruction
(
op
::
dot
{
alpha
},
l1
,
l2
);
if
(
num_squeeze
>
2
)
{
std
::
vector
<
int64_t
>
vec_axises
(
num_squeeze
-
2
);
std
::
iota
(
vec_axises
.
begin
(),
vec_axises
.
end
(),
0
);
l3
=
prog
.
add_instruction
(
op
::
unsqueeze
{
vec_axises
},
l3
);
}
auto
l4
=
args
[
2
];
auto
l4
=
args
[
2
];
if
(
l4
->
get_shape
().
scalar
())
// ignore args[2] (no C value added to alpha*A*B)
if
(
l4
->
get_shape
().
scalar
())
// ignore args[2] (no C value added to alpha*A*B)
return
l3
;
return
l3
;
...
@@ -480,7 +507,16 @@ struct onnx_parser
...
@@ -480,7 +507,16 @@ struct onnx_parser
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
return
add_broadcastable_binary_op
(
l3
,
l4
,
op
::
add
{});
}
}
}
}
return
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
auto
dot_res
=
prog
.
add_instruction
(
op
::
dot
{
alpha
,
beta
},
l1
,
l2
);
if
(
num_squeeze
>
2
)
{
std
::
vector
<
int64_t
>
vec_axises
(
num_squeeze
-
2
);
std
::
iota
(
vec_axises
.
begin
(),
vec_axises
.
end
(),
0
);
dot_res
=
prog
.
add_instruction
(
op
::
unsqueeze
{
vec_axises
},
dot_res
);
}
return
dot_res
;
}
}
instruction_ref
instruction_ref
...
...
src/targets/gpu/device/gather.cpp
View file @
2899f9a5
...
@@ -20,18 +20,65 @@ argument gather(hipStream_t stream,
...
@@ -20,18 +20,65 @@ argument gather(hipStream_t stream,
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
std
::
size_t
nelements
=
output_shape
.
elements
();
std
::
size_t
nelements
=
output_shape
.
elements
();
args
[
1
].
visit
([
&
](
auto
indices
)
{
args
[
1
].
visit
([
&
](
auto
indices
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
outptr
=
device_cast
(
output
.
data
());
auto
*
outptr
=
device_cast
(
output
.
data
());
const
auto
*
inptr
=
device_cast
(
input
.
data
());
const
auto
*
inptr
=
device_cast
(
input
.
data
());
hip_tensor_descriptor
<
ndim
>
desc_input
(
input
.
get_shape
());
if
(
output_shape
.
scalar
())
hip_tensor_descriptor
<
ndim
>
desc_output
(
output
.
get_shape
());
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
1
)([
=
](
auto
i
)
{
auto
lens
=
desc_output
.
multi
(
i
);
outptr
[
i
]
=
inptr
[
indices_ptr
[
0
]];
lens
[
axis_index
]
=
indices_ptr
[
lens
[
axis_index
]];
outptr
[
i
]
=
inptr
[
desc_input
.
linear
(
lens
)];
});
});
}
else
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
n_out_dim
)
{
visit_tensor_size
(
args
[
0
].
get_shape
().
lens
().
size
(),
[
&
](
auto
n_in_dim
)
{
hip_tensor_descriptor
<
n_in_dim
>
desc_input
(
input
.
get_shape
());
hip_tensor_descriptor
<
n_out_dim
>
desc_output
(
output
.
get_shape
());
if
(
args
[
1
].
get_shape
().
scalar
())
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
ii
)
{
auto
out_idx
=
desc_output
.
multi
(
ii
);
auto
in_idx
=
desc_input
.
multi
(
0
);
for
(
int
i
=
0
;
i
<
axis_index
;
++
i
)
{
in_idx
[
i
]
=
out_idx
[
i
];
}
in_idx
[
axis_index
]
=
indices_ptr
[
0
];
for
(
int
i
=
axis_index
+
1
;
i
<
n_in_dim
;
++
i
)
{
in_idx
[
i
]
=
out_idx
[
i
-
1
];
}
outptr
[
ii
]
=
inptr
[
desc_input
.
linear
(
in_idx
)];
});
});
}
else
{
visit_tensor_size
(
args
[
1
].
get_shape
().
lens
().
size
(),
[
&
](
auto
n_ind_dim
)
{
hip_tensor_descriptor
<
n_ind_dim
>
desc_ind
(
args
[
1
].
get_shape
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
ii
)
{
auto
out_idx
=
desc_output
.
multi
(
ii
);
auto
in_idx
=
desc_input
.
multi
(
0
);
for
(
int
i
=
0
;
i
<
axis_index
;
++
i
)
{
in_idx
[
i
]
=
out_idx
[
i
];
}
auto
ind_idx
=
desc_ind
.
multi
(
0
);
for
(
int
i
=
0
;
i
<
n_ind_dim
;
++
i
)
{
ind_idx
[
i
]
=
out_idx
[
i
+
axis_index
];
}
in_idx
[
axis_index
]
=
indices_ptr
[
desc_ind
.
linear
(
ind_idx
)];
for
(
int
i
=
axis_index
+
1
;
i
<
n_in_dim
;
++
i
)
{
in_idx
[
i
]
=
out_idx
[
i
+
n_ind_dim
-
1
];
}
outptr
[
ii
]
=
inptr
[
desc_input
.
linear
(
in_idx
)];
});
});
}
});
});
}
});
});
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment