if(half_to_float)AT_ASSERTM(input_.type().scalarType()==ScalarType::Half||input_.type().scalarType()==ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only");
AT_ASSERTM(labels_.type().scalarType()==ScalarType::Long,"Label type should be CUDA Long");
AT_ASSERTM((grad_loss.type().scalarType()==ScalarType::Float&&(logits.type().scalarType()==ScalarType::Half||logits.type().scalarType()==ScalarType::BFloat16)),"expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float");