OpenDAS / TransformerEngine · Commit 3b1f30a9
"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "7bb2af355bb2332a9c2152c6e155cbfc089f9660"
Commit 3b1f30a9, authored Jul 11, 2025 by wenjh

Merge branch 'develop_v2.4' into w8a8_dev_v2.4

Parents: 6a20ff90, 76023d21

Showing 8 changed files with 310 additions and 22 deletions (+310 −22).

tests/pytorch/test_int8_channelwise_gemm_exact.py (+55 −1)
transformer_engine/common/gemm/cublaslt_gemm.cu (+89 −19)
transformer_engine/common/gemm/hipblas_gemm.cu (+12 −2)
transformer_engine/common/include/transformer_engine/gemm.h (+5 −0)
transformer_engine/pytorch/csrc/extensions.h (+8 −0)
transformer_engine/pytorch/csrc/extensions/gemm.cpp (+132 −0)
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+7 −0)
transformer_engine/pytorch/triton/per_token_group_quant.py (+2 −0)

tests/pytorch/test_int8_channelwise_gemm_exact.py (+55 −1)

@@ -503,4 +503,58 @@ for i in range(20):

    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
    # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
torch.cuda.synchronize()
end = time.time()

# batch gemm wgrad
m = 32
k = 32
n = 32
b = 4
transa = False
transb = True
dy_int8 = torch.randn((b, m, n), device=device).to(dtype=torch.int8)
x_int8 = torch.randn((b, m, k), device=device).to(dtype=torch.int8)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

int32_dw_list = []
for i in range(b):
    int32_dw = torch._int_mm(dy_int8[i].t(), x_int8[i])
    # bf16_dw = torch.matmul(dy_int8[i].t(), x_int8[i])
    int32_dw_list.append(int32_dw)
batched_int32_dw = torch.stack(int32_dw_list)
# print("batched_int32_dw.shape: ", batched_int32_dw.shape)
# print("batched_int32_dw: ", batched_int32_dw)

out_dtype = torch.int32
out = torch.empty((b, n, k), dtype=out_dtype, device=device)
te_dw = tex.generic_batchgemm(
    x_int8.view(-1, x_int8.size(-1)),
    transa,
    dy_int8.view(-1, dy_int8.size(-1)),
    transb,
    out.view(-1, out.size(-1)),
    b,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw, batched_int32_dw, atol=1e-5, rtol=1e-5)

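Note: torch._int_mm accepts only 2-D int8 operands and returns an int32 product, which is why the reference result is built one batch slice at a time and stacked. The channelwise_dequantize_transA helper in the hunk's context lines is defined earlier in the test file; the sketch below is an assumed reconstruction of what such a helper computes (the _ref name and the scale shapes are assumptions, not code from this commit):

import torch

def channelwise_dequantize_transA_ref(dy_scales, x_scales, dw_int32):
    # Assumed shapes: dw_int32 is (n, k), with one dy scale per output row
    # and one x scale per output column; dequantization rescales each int32
    # accumulator by the product of its operands' per-channel scales.
    return dw_int32.float() * dy_scales.view(-1, 1) * x_scales.view(1, -1)
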
transformer_engine/common/gemm/cublaslt_gemm.cu (+89 −19)

@@ -986,24 +986,94 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,

  } else {
    // TT
    NVTE_ERROR("TT layout not allowed.");
  }
  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                    use_split_accumulator, math_sm_count, 0, 0, false, nullptr, batch_count,
                    stream);
}

// add for batchgemm
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int batch_count,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
  }

  int m, n, k;
  if (!transa && transb) {
    // NT
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (transa && !transb) {
    // TN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (!transa && !transb) {
    // NN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  }

  int lda, ldb, ldd;
  if (transa && !transb) {
    // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {
    // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {
    // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {
    // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                    use_split_accumulator, math_sm_count, 0, 0, false, nullptr, batch_count,
                    stream);
}
#endif

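Note: the three reachable branches of the m/n/k block above compute identical expressions; only the leading dimensions change with the layout. A small Python mirror of that shape logic (illustration only, not part of the commit; the function name is made up) shows how the batch dimension, folded into dim 0 of the 2-D inputs, is peeled off:

def batchgemm_dims(a_shape, b_shape, transa, transb, batch_count):
    # Inputs arrive 2-D with the batch folded into dim 0, so the per-matrix
    # row count is shape[0] // batch_count; TT is rejected as in the C++.
    if transa and transb:
        raise ValueError("TT layout not allowed.")
    m = a_shape[0] // batch_count if transa else a_shape[1]
    k = a_shape[1] if transa else a_shape[0] // batch_count
    n = b_shape[1] if transb else b_shape[0] // batch_count
    if transa:    # TN
        lda, ldb, ldd = k, k, m
    elif transb:  # NT
        lda, ldb, ldd = m, n, m
    else:         # NN
        lda, ldb, ldd = m, k, m
    return m, n, k, lda, ldb, ldd

# e.g. an NT wgrad case: batchgemm_dims((256, 64), (256, 32), False, True, 4)
# returns m=64, n=32, k=64, lda=64, ldb=32, ldd=64.
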
transformer_engine/common/gemm/hipblas_gemm.cu (+12 −2)

@@ -182,6 +182,16 @@ void hipblas_batchgemm(const Tensor *inputA,

  float one = 1.0f;
  float zero = 0.0f;
  float beta = accumulate ? one : zero;
  int int_one = 1;
  int int_zero = 0;
  int int_beta = int_zero;
  bool use_int8 = false;
  if ((A_type == HIPBLAS_R_8I) && (B_type == HIPBLAS_R_8I) && (D_type == HIPBLAS_R_32I)) {
    NVTE_CHECK(!accumulate, "Int8 gemm does not support accumulate.");
    use_int8 = true;
    computeType = HIPBLAS_R_32I;
  }
  hipblasSetStream(handle, stream);

  // execute multiply

@@ -197,7 +207,7 @@ void hipblas_batchgemm(const Tensor *inputA,

      m, n, k,
      use_int8 ? static_cast<const void *>(&int_one) : static_cast<const void *>(&one),
      A, A_type, lda,

@@ -206,7 +216,7 @@ void hipblas_batchgemm(const Tensor *inputA,

      B_type, ldb, strideB,
      use_int8 ? static_cast<const void *>(&int_beta) : static_cast<const void *>(&beta),
      D, D_type, ldd,

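Note: the two ternaries switch alpha and beta to int32 storage on the int8 path, since an int32 compute type expects integer scaling factors, and beta stays pinned to zero there, which is why accumulation into D is rejected up front. A hedged Python restatement of the scalar selection (the function name is illustrative):

def gemm_scalars(use_int8, accumulate):
    # Int8 inputs with an int32 output take integer alpha/beta; beta stays 0,
    # matching the NVTE_CHECK that forbids accumulate on the int8 path.
    if use_int8:
        assert not accumulate, "Int8 gemm does not support accumulate."
        return 1, 0
    return 1.0, 1.0 if accumulate else 0.0
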
transformer_engine/common/include/transformer_engine/gemm.h (+5 −0)

@@ -122,6 +122,11 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,

                           NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                           NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                           int math_sm_count, int batch_count, cudaStream_t stream);

void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int batch_count,
                              cudaStream_t stream);
#endif

#ifdef __cplusplus

transformer_engine/pytorch/csrc/extensions.h (+8 −0)

@@ -90,6 +90,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

                             std::optional<CommOverlapType> comm_type = std::nullopt,
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

std::vector<py::object> generic_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
    py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type,
    bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, DType B_type,
                    std::vector<int64_t> B_scaling_mode,

transformer_engine/pytorch/csrc/extensions/gemm.cpp (+132 −0)

@@ -271,6 +271,138 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

  return out;
}

std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb,
                                          py::object D, int batch_count, py::handle quantizer,
                                          std::optional<DType> out_dtype, MaybeTensor bias,
                                          DType bias_type, bool gelu, MaybeTensor gelu_in,
                                          bool grad, at::Tensor workspace, size_t workspaceSize,
                                          bool accumulate, bool use_split_accumulator,
                                          CommOverlapCore *comm_overlap,
                                          std::optional<CommOverlapType> comm_type,
                                          MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  auto none = py::none();
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);

  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto &A_shape = A_tensor.shape();
  const auto &B_shape = B_tensor.shape();
  const auto &D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("generic batchgemm D must not be None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()),
                 ")");
    }
  }

  // Bias tensor
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts =
          torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(
        std::move(swizzle_scaling_factors(B_tensor, !transb)));
    if (comm_overlap) {
      NVTE_ERROR("generic batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 num_math_sms, batch_count, main_stream);
      });
    }
  } else {
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, DType B_type,
                    std::vector<int64_t> B_scaling_mode,

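Note: when either input is empty, generic_batchgemm skips the kernel launch and instead zero-fills the output (and, in grad mode, the bias gradient) unless accumulation was requested. A Python-level sketch of that fallback (names are illustrative; arguments are assumed to be torch tensors):

def batchgemm_empty_fallback(a, b, d, accumulate, grad=False, bias_grad=None):
    # Mirrors the else-branch above: no GEMM runs when A or B has no elements.
    if a.numel() != 0 and b.numel() != 0:
        return False  # normal path: the batched GEMM is launched
    if d.numel() != 0 and not accumulate:
        d.zero_()
    if grad and bias_grad is not None and bias_grad.numel() != 0:
        bias_grad.zero_()
    return True  # GEMM was skipped
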
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+7 −0)

@@ -110,6 +110,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
  m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm,
        "Compute Batched GEMM (matrix-matrix multiply)", py::arg("A"), py::arg("transA"),
        py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"),
        py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
        py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),

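Note: with the keyword names registered above, the positional call in the test file can be written out explicitly. This is a restatement of that call, not a self-contained script; every value below (out_quantizer, bias, workspace, and so on) comes from the test's setup outside the quoted hunk:

te_dw = tex.generic_batchgemm(
    A=x_int8.view(-1, x_int8.size(-1)),
    transA=transa,
    B=dy_int8.view(-1, dy_int8.size(-1)),
    transB=transb,
    D=out.view(-1, out.size(-1)),
    batchcount=b,
    quantizer=out_quantizer,
    output_dtype=TE_DType[out_dtype],
    bias=bias,
    bias_type=bias_dtype,
    gelu=use_gelu,
    gelu_in=aux_tensor,
    grad=use_grad,
    workspace=workspace,
    workspace_size=workspace.shape[0],
    accumulate=accumulate,
    use_split_accumulator=use_split_accumulator,
)[0]  # returns [D, bias_grad, pre_gelu_out or None, extra_output or None]
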
transformer_engine/pytorch/triton/per_token_group_quant.py (+2 −0)

@@ -153,6 +153,8 @@ def _per_token_quant_fp8_to_int8(

def per_token_quant_fp8_to_int8(x, fp8_scale_inv, inplace=False):
    assert x.is_contiguous()
    x = x.view(-1, x.shape[-1])
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if inplace:

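Note: the added view folds any leading dimensions into the row count M before M and N are computed, so a (b, s, h) activation quantizes as b*s rows of h features. A quick illustration of the shape math:

import torch

x = torch.randn(4, 16, 128).contiguous()
x2d = x.view(-1, x.shape[-1])
M = x2d.numel() // x2d.shape[-1]  # 64 token rows = 4 * 16
N = x2d.shape[-1]                 # 128 features per token
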