Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
236a2030
Unverified
Commit
236a2030
authored
Jun 11, 2024
by
Keshav Balasubramanian
Committed by
GitHub
Jun 11, 2024
Browse files
Value initialize packing descriptors (#912)
Signed-off-by:
Keshav
<
keshavb@nvidia.com
>
parent
c054c06f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
transformer_engine/jax/csrc/extensions/packing.cpp
transformer_engine/jax/csrc/extensions/packing.cpp
+3
-3
No files found.
transformer_engine/jax/csrc/extensions/packing.cpp
View file @
236a2030
...
@@ -12,7 +12,7 @@ namespace jax {
...
@@ -12,7 +12,7 @@ namespace jax {
pybind11
::
bytes
PackCustomCallCommonDescriptor
(
const
std
::
vector
<
size_t
>
&
shape
,
DType
in_dtype
,
pybind11
::
bytes
PackCustomCallCommonDescriptor
(
const
std
::
vector
<
size_t
>
&
shape
,
DType
in_dtype
,
DType
out_dtype
,
size_t
act_enum
)
{
DType
out_dtype
,
size_t
act_enum
)
{
CustomCallCommonDescriptor
desc
;
CustomCallCommonDescriptor
desc
{}
;
desc
.
shape
.
from_vector
(
shape
);
desc
.
shape
.
from_vector
(
shape
);
desc
.
in_dtype
=
in_dtype
;
desc
.
in_dtype
=
in_dtype
;
desc
.
out_dtype
=
out_dtype
;
desc
.
out_dtype
=
out_dtype
;
...
@@ -24,7 +24,7 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
...
@@ -24,7 +24,7 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
const
std
::
vector
<
size_t
>
&
wkshape
,
DType
in_dtype
,
const
std
::
vector
<
size_t
>
&
wkshape
,
DType
in_dtype
,
DType
out_dtype
,
DType
wk_dtype
,
DType
out_dtype
,
DType
wk_dtype
,
size_t
act_enum
)
{
size_t
act_enum
)
{
CustomCallCommonWkDescriptor
desc
;
CustomCallCommonWkDescriptor
desc
{}
;
desc
.
shape
.
from_vector
(
shape
);
desc
.
shape
.
from_vector
(
shape
);
desc
.
wkshape
.
from_vector
(
wkshape
);
desc
.
wkshape
.
from_vector
(
wkshape
);
desc
.
in_dtype
=
in_dtype
;
desc
.
in_dtype
=
in_dtype
;
...
@@ -39,7 +39,7 @@ pybind11::bytes PackCustomCallNormDescriptor(
...
@@ -39,7 +39,7 @@ pybind11::bytes PackCustomCallNormDescriptor(
const
std
::
vector
<
size_t
>
&
dgamma_part_shape
,
const
std
::
vector
<
size_t
>
&
dbeta_part_shape
,
const
std
::
vector
<
size_t
>
&
dgamma_part_shape
,
const
std
::
vector
<
size_t
>
&
dbeta_part_shape
,
DType
x_dtype
,
DType
w_dtype
,
DType
wkspace_dtype
,
DType
barrier_dtype
,
DType
dgamma_part_dtype
,
DType
x_dtype
,
DType
w_dtype
,
DType
wkspace_dtype
,
DType
barrier_dtype
,
DType
dgamma_part_dtype
,
DType
dbeta_part_dtype
,
bool
zero_centered_gamma
,
float
eps
,
int
sm_margin
)
{
DType
dbeta_part_dtype
,
bool
zero_centered_gamma
,
float
eps
,
int
sm_margin
)
{
CustomCallNormDescriptor
desc
;
CustomCallNormDescriptor
desc
{}
;
desc
.
batch_size
=
batch_size
;
desc
.
batch_size
=
batch_size
;
desc
.
hidden_size
=
hidden_size
;
desc
.
hidden_size
=
hidden_size
;
desc
.
wkspace_size
=
wkspace_size
;
desc
.
wkspace_size
=
wkspace_size
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment