Commit da77a836 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into api_change

parents 397714f5 b31e1296
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "ATen/ATen.h" #include "ATen/ATen.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh> #include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -392,14 +392,18 @@ def get_bias_dropout_add(training): ...@@ -392,14 +392,18 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob): def bias_dropout_add_fused_train(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob): def bias_dropout_add_fused_inference(x: torch.Tensor,
# type: (Tensor, Tensor, Tensor, float) -> Tensor bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
......
...@@ -70,9 +70,10 @@ class MegatronGenerate(Resource): ...@@ -70,9 +70,10 @@ class MegatronGenerate(Resource):
temperature = args.temperature temperature = args.temperature
if "temperature" in request.get_json(): if "temperature" in request.get_json():
temperature = request.get_json()["temperature"] temperature = request.get_json()["temperature"]
if not isinstance(temperature, float) or not \ if not (type(temperature) == int or type(temperature) == float):
0.0 < temperature <= 100.0: return "temperature must be a positive number less than or equal to 100.0"
return "temperature must be a positive float less than or equal to 100.0" if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = args.top_k top_k = args.top_k
if "top_k" in request.get_json(): if "top_k" in request.get_json():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment