Commit 2ce87038 authored by zms1999's avatar zms1999
Browse files

skip computeFn when micro_batch_size == 0

parent 59bcec8e
......@@ -80,6 +80,9 @@ void computeFn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
long idx, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) {
if(micro_batch_size == 0) {
return;
}
auto options = torch::TensorOptions()
.dtype(c10::CppTypeToScalarType<scalar_t>::value)
.device(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment