"examples/vscode:/vscode.git/clone" did not exist on "cf36f4d7a8741a74fd156fe1016fe49dff38da1a"
Unverified Commit b936582f authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing flaky conversational test + flag it as a pipeline test. (#9837)

parent 58fbef9e
......@@ -22,7 +22,7 @@ from transformers import (
is_torch_available,
pipeline,
)
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device
from .test_pipelines_common import MonoInputPipelineCommonMixin
......@@ -35,6 +35,7 @@ if is_torch_available():
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
@is_pipeline_test
class SimpleConversationPipelineTests(unittest.TestCase):
def get_pipeline(self):
# When
......@@ -52,9 +53,11 @@ class SimpleConversationPipelineTests(unittest.TestCase):
# Force model output to be L
V, D = model.lm_head.weight.shape
bias = torch.zeros(V, requires_grad=True)
weight = torch.zeros((V, D), requires_grad=True)
bias[76] = 1
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)
# # Created with:
# import tempfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment