Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
26c33a16
".github/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "b7d385e812814e433dd45ea342c03b42034d9123"
Commit
26c33a16
authored
Nov 07, 2018
by
Scott Thornton
Browse files
Added multibroadcast + test
parent
2946e34e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
0 deletions
+81
-0
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+47
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+34
-0
No files found.
src/include/migraph/operators.hpp
View file @
26c33a16
...
...
@@ -759,6 +759,53 @@ struct broadcast
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
multibroadcast
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
string
name
()
const
{
return
"multibroadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
<=
0
)
MIGRAPH_THROW
(
"inputs dimensions should be > 0"
);
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPH_THROW
(
"inputs dimensions should <= output size"
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
extra
=
output_lens
.
size
()
-
input
.
lens
().
size
();
if
(
input
.
lens
().
size
()
<
output_lens
.
size
())
{
for
(
std
::
size_t
i
=
output_lens
.
size
()
-
1
;
i
>
0
;
i
--
)
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
-
extra
])
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
-
extra
];
}
}
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
input
.
lens
().
size
();
i
++
)
{
if
(
output_lens
[
i
]
==
input
.
lens
()[
i
])
{
bcast_strides
[
i
]
=
input
.
strides
()[
i
];
}
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
scalar
{
shape
scalar_bcast
;
...
...
test/op_shape_test.cpp
View file @
26c33a16
...
...
@@ -145,8 +145,42 @@ void slice_shape()
migraph
::
op
::
slice
{{
2
},
{
2
},
{
10
}},
input
);
}
void
multibroadcast_shape
()
{
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
2
,
5
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
1
,
3
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
lens
,
{
0
,
3
,
0
,
1
}},
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
2
,
5
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
1
,
1
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
lens
,
{
0
,
1
,
0
,
0
}},
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
1
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
1
,
1
,
1
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
lens
,
{
1
,
1
,
1
,
0
}},
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
1
,
1
,
1
}};
throws_shape
(
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
{
std
::
vector
<
std
::
size_t
>
lens
{
4
,
1
,
3
};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{}};
throws_shape
(
migraph
::
op
::
multibroadcast
{
lens
},
input
);
}
}
int
main
()
{
multibroadcast_shape
();
batch_norm_inference_shape
();
convolution_shape
();
transpose_shape
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment