Commit 838af7d9 authored by Sangkug Lym's avatar Sangkug Lym
Browse files

jit function argument type specification

parent 386923b5
...@@ -392,14 +392,18 @@ def get_bias_dropout_add(training): ...@@ -392,14 +392,18 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob): def bias_dropout_add_fused_train(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob): def bias_dropout_add_fused_inference(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment