The file "sgl-kernel" did not exist on commit "8fdcd98efefcd98a804c412d16a95e6f59f17f54".
Commit 907ca927 authored by Myle Ott
Browse files

Better support for torch.no_grad (since volatile is deprecated)

parent 0b84ab19
......@@ -69,7 +69,8 @@ class LinearizedConvolution(ConvTBC):
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = utils.volatile_variable(self.input_buffer)
output = F.linear(input.view(bsz, -1), weight, self.bias)
with utils.maybe_no_grad():
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
def clear_incremental_state(self):
......
......@@ -71,8 +71,9 @@ class SequenceGenerator(object):
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
hypos = self.generate(input['src_tokens'], beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b))
with utils.maybe_no_grad():
hypos = self.generate(input['src_tokens'], beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id'].data):
......@@ -327,7 +328,8 @@ class SequenceGenerator(object):
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, encoder_out)
with utils.maybe_no_grad():
decoder_out, attn = model.decoder(tokens, encoder_out)
probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if avg_probs is None:
avg_probs = probs
......
......@@ -176,7 +176,7 @@ def _upgrade_args(args):
return args
def maybe_no_grad(condition=True):
    """Return a context manager that disables gradient tracking when possible.

    Args:
        condition: if True (the default) and this version of torch provides
            ``torch.no_grad`` (i.e. volatile is deprecated), gradients are
            disabled inside the returned context; otherwise the returned
            context is a no-op.

    Returns:
        A context manager: ``torch.no_grad()`` when available and *condition*
        holds, else a no-op context manager.
    """
    if hasattr(torch, 'no_grad') and condition:
        return torch.no_grad()
    # no-op context manager
    # NOTE(review): the original no-op fallback body is truncated in the diff
    # view; ExitStack() with no registered callbacks is an equivalent no-op.
    import contextlib
    return contextlib.ExitStack()
......@@ -185,9 +185,10 @@ def maybe_no_grad(condition):
def volatile_variable(*args, **kwargs):
    """Build a Variable intended for inference-only (no-gradient) use.

    On newer torch versions ``volatile`` is deprecated in favor of the
    ``torch.no_grad()`` context manager, which must wrap the *operations*
    (callers should use ``maybe_no_grad()``); so here we simply construct a
    plain Variable. On older versions we fall back to ``volatile=True``.

    Args:
        *args, **kwargs: forwarded verbatim to ``Variable``.

    Returns:
        A ``Variable`` wrapping the given data.
    """
    if hasattr(torch, 'no_grad'):
        # volatile has been deprecated, use the no_grad context manager instead
        return Variable(*args, **kwargs)
    else:
        return Variable(*args, **kwargs, volatile=True)
def make_variable(sample, volatile=False, cuda_device=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment