Unverified Commit 170203ae authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Performance] Creating out buffers for `segment_mm`|`sddmm` via `torch.empty()` (#5462)

* update for segmentMM

* update for sddmm

* fix a bug
parent 0ac9bf34
......@@ -549,7 +549,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target="u", rhs_target="v"):
out_shp = (gidx.number_of_edges(0),) + infer_broadcast_shape(
op, lhs_shp[1:], rhs_shp[1:]
)
out = F.zeros(out_shp, dtype, ctx)
out = F.empty(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMM(
gidx,
......@@ -615,7 +615,7 @@ def _gsddmm_hetero(
out_shp = (gidx.number_of_edges(etid),) + infer_broadcast_shape(
op, lhs_shp[1:], rhs_shp[1:]
)
out_list[etid] = F.zeros(out_shp, dtype, ctx)
out_list[etid] = F.empty(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMMHetero(
gidx,
......
......@@ -967,6 +967,26 @@ def swapaxes(input, axis1, axis2):
pass
def empty(shape, dtype, ctx):
"""Create a tensor filled with uninitialized data.
Parameters
----------
shape : tuple of int
The tensor shape.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns
-------
Tensor
The empty (uninitialized) tensor.
"""
pass
def zeros(shape, dtype, ctx):
"""Create a zero tensor.
......
......@@ -347,6 +347,10 @@ def swapaxes(input, axis1, axis2):
return nd.swapaxes(input, axis1, axis2)
# Backend implementation of empty(): forwards shape, dtype and ctx unchanged
# to nd.empty and returns its result.
# NOTE(review): `nd` is presumably mxnet.ndarray; its import is outside this
# hunk, so confirm.
def empty(shape, dtype, ctx):
return nd.empty(shape, dtype=dtype, ctx=ctx)
def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx)
......
......@@ -970,7 +970,7 @@ class SEGMENTMM(th.autograd.Function):
def forward(ctx, A, B, seglen_A):
if B.dim() != 3:
raise ValueError("segment_mm expects B to be a 3D tensor.")
C = th.zeros((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
C = th.empty((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
C = _segment_mm(A, B, C, seglen_A)
ctx.backward_cache = A, B, seglen_A
return C
......@@ -981,11 +981,11 @@ class SEGMENTMM(th.autograd.Function):
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
A_grad = th.empty(A.shape, device=A.device, dtype=A.dtype)
A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = th.empty(B.shape, device=B.device, dtype=B.dtype)
B_grad = _segment_mm_backward_B(A, dZ, B_grad, seglen_A)
return A_grad, B_grad, None
......
......@@ -279,6 +279,10 @@ def swapaxes(input, axis1, axis2):
return th.transpose(input, axis1, axis2)
# Backend implementation of empty(): forwards shape and dtype to th.empty and
# passes ctx as the `device` argument, returning the resulting tensor.
# Callers must fully overwrite the result before reading it, since the
# contents are not initialized.
def empty(shape, dtype, ctx):
return th.empty(shape, dtype=dtype, device=ctx)
def zeros(shape, dtype, ctx):
return th.zeros(shape, dtype=dtype, device=ctx)
......
......@@ -336,6 +336,11 @@ def swapaxes(input, axis1, axis2):
return tf.transpose(input, perm=t)
# Backend implementation of empty() for TensorFlow. Unlike a true
# uninitialized allocation, this returns whatever zeros(shape, dtype, ctx)
# returns, so the result is zero-filled on this backend.
def empty(shape, dtype, ctx):
# tf doesn't have tf.empty(), use zeros() as a workaround
return zeros(shape, dtype, ctx)
def zeros(shape, dtype, ctx):
with tf.device(ctx):
t = tf.zeros(shape, dtype=dtype)
......
......@@ -256,7 +256,7 @@ void SegmentMMBackwardB(
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen.NumElements();
DType alpha = 1., beta = 1.;
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment