Unverified Commit b07770c5 authored by Anton Vlasjuk's avatar Anton Vlasjuk Committed by GitHub
Browse files

[`GPT-NeoX`] Add SDPA support (#31031)

* starting support for sdpa in `gptneox` models

* small comment on tests

* fix dropout

* documentation and style

* clarify concrete paths for reference

* generalise attn projections and rope application

added head mask check to sdpa mask creation

handle sdpa memory backend bug via own version flag

* update docs and style

* move dtype casting outside of general attn_projection_and_rope function

fix flash_attn_2 stuff

* more generic attn warning if output_attns or head_mask

* simplify head mask check by moving head mask creation to a later point

* remove copied llama artifact

* remove padding_mask from attention function signature

* removing unnecessary comments, only "save" attn implementation once

* [run_slow] gpt_neox
parent 1218e439
......@@ -95,6 +95,68 @@ Below is an expected speedup diagram that compares pure inference time between t
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/gpt-neox-1.8b-speedup.jpg">
</div>
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
from transformers import GPTNeoXForCausalLM
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with
[pythia-410m-deduped](https://huggingface.co/EleutherAI/pythia-410m-deduped), we saw the
following speedups during training and inference.
### Training
| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|-----------:|-----------:|---------------------------:|-----------------------------:|------------:|--------------------:|-------------------:|------------------:|
| 1 | 128 | 0.024 | 0.019 | 28.945 | 1789.95 | 1789.95 | 0 |
| 1 | 256 | 0.039 | 0.031 | 23.18 | 1845.83 | 1844.84 | 0.053 |
| 1 | 512 | 0.08 | 0.055 | 45.524 | 2278.38 | 1953.76 | 16.615 |
| 1 | 1024 | 0.19 | 0.102 | 86.777 | 4772.36 | 2408.35 | 98.159 |
| 1 | 2048 | 0.565 | 0.204 | 177.098 | 13484.1 | 3882.01 | 247.348 |
| 2 | 128 | 0.037 | 0.032 | 15.121 | 1843.86 | 1844.78 | -0.05 |
| 2 | 256 | 0.067 | 0.055 | 21.706 | 1999.72 | 1951.67 | 2.462 |
| 2 | 512 | 0.144 | 0.096 | 50.046 | 3613.16 | 2406.77 | 50.125 |
| 2 | 1024 | 0.366 | 0.193 | 89.666 | 8707.55 | 3878.86 | 124.487 |
| 2 | 2048 | OOM | 0.379 | / | OOM | 6825.13 | SDPA does not OOM |
| 4 | 128 | 0.06 | 0.054 | 11.539 | 1947.6 | 1952.06 | -0.228 |
| 4 | 256 | 0.119 | 0.093 | 28.072 | 3008.39 | 2405.99 | 25.038 |
| 4 | 512 | 0.275 | 0.187 | 47.145 | 6290.58 | 3877.29 | 62.242 |
| 4 | 1024 | OOM | 0.36 | / | OOM | 6821.98 | SDPA does not OOM |
| 4 | 2048 | OOM | 0.731 | / | OOM | 12705.1 | SDPA does not OOM |
### Inference
| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|--------------:|-------------:|--------------------------------:|-------------------------------:|---------------:|------------------:|----------------:|-----------------:|
| 1 | 128 | 6.569 | 5.858 | 12.14 | 974.831 | 974.826 | 0 |
| 1 | 256 | 7.009 | 5.863 | 19.542 | 1029.01 | 1028.08 | 0.09 |
| 1 | 512 | 7.157 | 5.965 | 19.983 | 1137.54 | 1137.52 | 0.001 |
| 1 | 1024 | 7.523 | 6.506 | 15.637 | 1329.3 | 1329.26 | 0.003 |
| 1 | 2048 | 9.271 | 9.205 | 0.713 | 1752.47 | 1734.51 | 1.036 |
| 2 | 128 | 7.239 | 5.959 | 21.493 | 1044.8 | 1028.37 | 1.597 |
| 2 | 256 | 7.228 | 6.036 | 19.757 | 1167.32 | 1137.73 | 2.601 |
| 2 | 512 | 7.538 | 6.693 | 12.628 | 1352.93 | 1329.55 | 1.758 |
| 2 | 1024 | 8.916 | 8.632 | 3.291 | 1752.56 | 1734.62 | 1.034 |
| 2 | 2048 | 12.628 | 12.606 | 0.181 | 2558.72 | 2545.8 | 0.508 |
| 4 | 128 | 7.278 | 6.046 | 20.373 | 1168.41 | 1137.79 | 2.691 |
| 4 | 256 | 7.614 | 6.588 | 15.574 | 1353.1 | 1329.79 | 1.753 |
| 4 | 512 | 8.798 | 8.144 | 8.028 | 1752.76 | 1734.85 | 1.032 |
| 4 | 1024 | 11.765 | 11.303 | 4.09 | 2558.96 | 2546.04 | 0.508 |
| 4 | 2048 | 19.568 | 17.735 | 10.33 | 4175.5 | 4165.26 | 0.246 |
## Resources
- [Causal language modeling task guide](../tasks/language_modeling)
......
......@@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
......
......@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
......@@ -29,6 +30,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
......@@ -37,7 +39,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from ...utils import get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_gpt_neox import GPTNeoXConfig
......@@ -78,6 +80,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -162,7 +165,56 @@ class GPTNeoXAttention(nn.Module):
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None,
):
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
)
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn_projections_and_rope(
self,
hidden_states: torch.FloatTensor,
position_ids: torch.LongTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
):
has_layer_past = layer_past is not None
......@@ -204,43 +256,7 @@ class GPTNeoXAttention(nn.Module):
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
# -> [bs, seq_len, hidden_size]
return tensor
return query, key, value, present
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
......@@ -319,48 +335,13 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
)
query_length = query.shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
target_dtype = value.dtype
if query.dtype != target_dtype:
......@@ -516,6 +497,90 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
)
class GPTNeoXSdpaAttention(GPTNeoXAttention):
"""
GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GPTNeoXAttention` as the weights of the module stays untouched. The only changes are on the forward pass
to adapt to the SDPA API.
"""
def __init__(self, config):
super().__init__(config)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
if output_attentions or head_mask is not None:
logger.warning_once(
"`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
bsz, q_len, _ = hidden_states.size()
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
)
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
target_dtype = value.dtype
if query.dtype != target_dtype:
query = query.to(target_dtype)
if key.dtype != target_dtype:
key = key.to(target_dtype)
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=is_causal,
)
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.dense(attn_output)
return attn_output, present, None
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
......@@ -660,6 +725,7 @@ class GPTNeoXMLP(nn.Module):
GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
"sdpa": GPTNeoXSdpaAttention,
}
......@@ -786,7 +852,8 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._attn_implementation = config._attn_implementation
self.gradient_checkpointing = False
......@@ -859,27 +926,29 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if self._use_flash_attention_2:
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
......@@ -888,9 +957,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = self.emb_dropout(inputs_embeds)
if self.gradient_checkpointing and self.training:
......
......@@ -19,7 +19,7 @@ import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -396,6 +396,68 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Based on tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate
which also overwrites the common test as the test is flaky on tiny models.
"""
max_new_tokens = 30
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
model_sdpa = GPTNeoXForCausalLM.from_pretrained(
"EleutherAI/pythia-1b",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = GPTNeoXForCausalLM.from_pretrained(
"EleutherAI/pythia-1b",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
texts = [
"hi here's a longer context, getting longer and",
"Hello this is a very long sentence my friend, very long for real",
"Today I am in Paris and",
]
for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)
@require_torch
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment