Commit d9d7d1a4 authored by thomwolf's avatar thomwolf
Browse files

update float()

parent c6207d85
...@@ -1015,7 +1015,7 @@ ...@@ -1015,7 +1015,7 @@
" print(input_mask)\n", " print(input_mask)\n",
" print(example_indices)\n", " print(example_indices)\n",
" input_ids = input_ids.to(device)\n", " input_ids = input_ids.to(device)\n",
" input_mask = input_mask.float().to(device)\n", " input_mask = input_mask.to(device)\n",
"\n", "\n",
" all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n", " all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n",
"\n", "\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment