if(half_to_float)AT_ASSERTM(input_.type().scalarType()==ScalarType::Half,"conversion is supported for Half type only");
if(half_to_float)AT_ASSERTM(input_.type().scalarType()==ScalarType::Half||input_.type().scalarType()==ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only");
AT_ASSERTM(labels_.type().scalarType()==ScalarType::Long,"Label type should be CUDA Long");
AT_ASSERTM(labels_.type().scalarType()==ScalarType::Long,"Label type should be CUDA Long");
AT_ASSERTM((grad_loss.type().scalarType()==ScalarType::Float&&logits.type().scalarType()==ScalarType::Half),"expected input and grad types to match, or input to be at::Half and grad to be at::Float");
AT_ASSERTM((grad_loss.type().scalarType()==ScalarType::Float&&(logits.type().scalarType()==ScalarType::Half||logits.type().scalarType()==ScalarType::BFloat16)),"expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float");
raiseRuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
raiseRuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
raiseRuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
raiseRuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
else:
# Check, if CUDA11 is installed for compute capability 8.0
# Check, if CUDA11 is installed for compute capability 8.0