Commit 83c5f3b8 authored by gushiqiao's avatar gushiqiao Committed by Yang Yong(雍洋)
Browse files

Fix bugs

parent 683aaa3a
......@@ -22,17 +22,17 @@ class LNWeightTemplate(metaclass=ABCMeta):
if config is not None:
self.config = config
def to_cpu(self):
def to_cpu(self, non_blocking=False):
if self.weight is not None:
self.weight = self.weight.cpu()
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cpu()
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def to_cuda(self):
def to_cuda(self, non_blocking=False):
if self.weight is not None:
self.weight = self.weight.cuda()
self.weight = self.weight.cuda(non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cuda()
self.bias = self.bias.cuda(non_blocking=non_blocking)
@LN_WEIGHT_REGISTER("Default")
......
......@@ -21,11 +21,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
if config is not None:
self.config = config
def to_cpu(self):
self.weight = self.weight.cpu()
def to_cpu(self, non_blocking=False):
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def to_cuda(self):
self.weight = self.weight.cuda()
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
@RMS_WEIGHT_REGISTER("Default")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment