"docs/source/vscode:/vscode.git/clone" did not exist on "98e409abb3a91e4de2c27e1c729ba9fdaa74d3df"
Unverified Commit d9a81fc0 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix func signature (#10271)

parent c6fe1755
...@@ -146,7 +146,7 @@ if is_torch_available(): ...@@ -146,7 +146,7 @@ if is_torch_available():
self.double_output = double_output self.double_output = double_output
self.config = None self.config = None
def forward(self, input_x=None, labels=None, **kwargs): def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b y = input_x * self.a + self.b
if labels is None: if labels is None:
return (y, y) if self.double_output else (y,) return (y, y) if self.double_output else (y,)
...@@ -160,7 +160,7 @@ if is_torch_available(): ...@@ -160,7 +160,7 @@ if is_torch_available():
self.b = torch.nn.Parameter(torch.tensor(b).float()) self.b = torch.nn.Parameter(torch.tensor(b).float())
self.config = None self.config = None
def forward(self, input_x=None, labels=None, **kwargs): def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b y = input_x * self.a + self.b
result = {"output": y} result = {"output": y}
if labels is not None: if labels is not None:
...@@ -177,7 +177,7 @@ if is_torch_available(): ...@@ -177,7 +177,7 @@ if is_torch_available():
self.b = torch.nn.Parameter(torch.tensor(config.b).float()) self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.double_output = config.double_output self.double_output = config.double_output
def forward(self, input_x=None, labels=None, **kwargs): def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b y = input_x * self.a + self.b
if labels is None: if labels is None:
return (y, y) if self.double_output else (y,) return (y, y) if self.double_output else (y,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment