"vscode:/vscode.git/clone" did not exist on "f24a6f5988dc5330cd22acbafba64d5937c922ce"
Commit e0c65b0d authored by chenpangpang's avatar chenpangpang
Browse files

Add new file

parent b2d7fb87
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import time
# Report which backend JAX is running on (GPU if available, else CPU/TPU).
print("JAX版本:", jax.__version__)
print("可用设备:", jax.devices())
print("当前使用的设备:", jax.devices()[0])
print("设备类型:", jax.devices()[0].device_kind)
# 1. Basic GPU matrix-operation example
def basic_gpu_operations(size=5000):
    """Time a large square matrix multiply on the default JAX device.

    Args:
        size: Side length of the square matrices (default 5000).

    Returns:
        The product matrix ``C`` of shape ``(size, size)``.
    """
    print("\n=== 基础GPU矩阵运算 ===")
    # Split into independent subkeys so A and B are uncorrelated.
    # (The original reused one key, which made A identical to B.)
    key = random.PRNGKey(42)
    key_a, key_b = random.split(key)
    A = random.normal(key_a, (size, size))
    B = random.normal(key_b, (size, size))
    # JAX dispatches asynchronously: without block_until_ready() we would
    # only time the dispatch, not the actual matrix multiply.
    start = time.time()
    C = jnp.dot(A, B).block_until_ready()
    compute_time = time.time() - start
    print(f"矩阵大小: {size}x{size}")
    print(f"GPU矩阵乘法耗时: {compute_time:.4f}秒")
    print(f"结果矩阵C的形状: {C.shape}")
    print(f"C的平均值: {jnp.mean(C):.6f}")
    print(f"C的标准差: {jnp.std(C):.6f}")
    return C
# 2. JIT-compilation speed-up example
def jit_acceleration(shape=(10000, 1000)):
    """Compare eager vs. jit-compiled execution of an elementwise reduction.

    Args:
        shape: Shape of the random input array (default ``(10000, 1000)``).

    Returns:
        The scalar result produced by the jitted function.
    """
    print("\n=== JIT编译加速示例 ===")

    # A moderately complex elementwise expression reduced to a scalar.
    def complex_function(x):
        return jnp.sum(jnp.sin(x) * jnp.exp(-x**2) + jnp.log1p(jnp.abs(x)))

    key = random.PRNGKey(0)
    x = random.normal(key, shape)
    # block_until_ready() is required for meaningful timings because JAX
    # dispatches work asynchronously.
    start = time.time()
    result_no_jit = complex_function(x).block_until_ready()
    time_no_jit = time.time() - start
    jitted_function = jit(complex_function)
    # First call pays the tracing/compilation cost.
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_first_call = time.time() - start
    # Second call reuses the cached compiled executable.
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_second_call = time.time() - start
    print(f"不使用JIT: {time_no_jit:.4f}秒")
    print(f"JIT首次调用(含编译): {time_first_call:.4f}秒")
    print(f"JIT后续调用: {time_second_call:.4f}秒")
    print(f"加速比: {time_no_jit/time_second_call:.2f}倍")
    return result_jit
# 3. Automatic differentiation on the GPU
def autodiff_gpu(n_samples=10000, n_features=100):
    """Compute MSE-loss gradients for a linear model on synthetic data.

    Args:
        n_samples: Number of synthetic data rows (default 10000).
        n_features: Number of input features (default 100).

    Returns:
        Tuple ``(dw, db)`` of gradients with shapes
        ``(n_features, 1)`` and ``(1,)``.
    """
    print("\n=== GPU上的自动微分 ===")

    # Mean-squared-error loss of a linear model y ≈ x @ w + b.
    def loss(params, x, y):
        w, b = params
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y)**2)

    # One independent subkey per random draw. (The original reused key1
    # for X, the observation noise, AND the initial weights, which
    # correlates quantities that should be independent.)
    key = random.PRNGKey(1)
    key_x, key_w, key_b, key_noise, key_init = random.split(key, 5)
    X = random.normal(key_x, (n_samples, n_features))
    true_w = random.normal(key_w, (n_features, 1))
    true_b = random.normal(key_b, (1,))
    Y = jnp.dot(X, true_w) + true_b + 0.1 * random.normal(key_noise, (n_samples, 1))
    params = (random.normal(key_init, (n_features, 1)), jnp.zeros((1,)))
    # argnums=0 differentiates with respect to params only.
    grad_loss = jit(grad(loss, argnums=0))
    start = time.time()
    gradients = grad_loss(params, X, Y)
    # Wait for the async computation so the timing is real.
    jax.block_until_ready(gradients)
    grad_time = time.time() - start
    print(f"数据形状: X={X.shape}, Y={Y.shape}")
    print(f"GPU梯度计算耗时: {grad_time:.4f}秒")
    print(f"权重梯度形状: {gradients[0].shape}")
    print(f"偏置梯度形状: {gradients[1].shape}")
    return gradients
# 4. Batched processing example
def batch_processing(batch_size=10000, sample_size=1000):
    """Compare a Python loop against ``vmap`` for per-sample reductions.

    Args:
        batch_size: Number of samples in the batch (default 10000).
        sample_size: Length of each sample vector (default 1000).

    Returns:
        Array of shape ``(batch_size,)`` with one reduced value per sample.
    """
    print("\n=== GPU批量处理 ===")

    # Per-sample transform: reduce tanh(x) * cos(x) to a scalar.
    def process_sample(x):
        return jnp.sum(jnp.tanh(x) * jnp.cos(x))

    # vmap vectorizes process_sample over the leading batch axis.
    batched_process = vmap(process_sample)
    key = random.PRNGKey(3)
    batch_data = random.normal(key, (batch_size, sample_size))

    # Baseline: one Python-level call per sample (slow on purpose).
    def sequential_process(data):
        return jnp.array([process_sample(data[i]) for i in range(data.shape[0])])

    # block_until_ready() ensures each computation has finished before
    # the clock is read (JAX dispatch is asynchronous).
    start = time.time()
    sequential_results = sequential_process(batch_data).block_until_ready()
    sequential_time = time.time() - start
    start = time.time()
    batched_results = batched_process(batch_data).block_until_ready()
    batched_time = time.time() - start
    print(f"批量大小: {batch_size}, 样本大小: {sample_size}")
    print(f"逐个处理耗时: {sequential_time:.4f}秒")
    print(f"批量处理耗时: {batched_time:.4f}秒")
    print(f"加速比: {sequential_time/batched_time:.2f}倍")
    # Both paths compute the same values; report the worst discrepancy.
    error = jnp.max(jnp.abs(sequential_results - batched_results))
    print(f"结果最大误差: {error:.10f}")
    return batched_results
# 5. Advanced example: a simple two-layer neural network
def neural_network_example(batch_size=1024):
    """Run a jitted forward pass of a small two-layer MLP on random input.

    Args:
        batch_size: Number of input rows to push through the net
            (default 1024).

    Returns:
        Output logits of shape ``(batch_size, 10)``.
    """
    print("\n=== GPU上的神经网络层 ===")

    # ReLU activation.
    def relu(x):
        return jnp.maximum(0, x)

    # Forward pass: dense -> ReLU -> dense.
    # (The original also defined an unused dense_layer helper; removed.)
    def two_layer_net(params, x):
        w1, b1, w2, b2 = params
        hidden = relu(jnp.dot(x, w1) + b1)
        output = jnp.dot(hidden, w2) + b2
        return output

    key = random.PRNGKey(4)
    # 4-way split kept so key3 yields the same values as the original
    # (the fourth subkey was never used).
    key1, key2, key3, _ = random.split(key, 4)
    input_size = 784   # MNIST image size
    hidden_size = 128
    output_size = 10   # 10 classes
    # He-style initialization: scale by sqrt(2 / fan_in).
    w1 = random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)
    params = (w1, b1, w2, b2)
    x_batch = random.normal(key3, (batch_size, input_size))
    jitted_net = jit(two_layer_net)
    # block_until_ready() so the timing covers the actual forward pass,
    # not just the asynchronous dispatch.
    start = time.time()
    outputs = jitted_net(params, x_batch).block_until_ready()
    inference_time = time.time() - start
    print(f"网络架构: {input_size} -> {hidden_size} -> {output_size}")
    print(f"批量大小: {batch_size}")
    print(f"GPU推理耗时: {inference_time:.6f}秒")
    print(f"输出形状: {outputs.shape}")
    print(f"每个样本的平均输出: {jnp.mean(outputs, axis=0)}")
    return outputs
def main():
    """Run every demo section in order, then capture a short profiler trace."""
    banner = "=" * 50
    print(banner)
    print("JAX GPU Demo - 开始执行")
    print(banner)
    try:
        # Each section prints its own report; return values are unused here.
        basic_gpu_operations()     # 1. basic GPU matrix ops
        jit_acceleration()         # 2. JIT speed-up
        autodiff_gpu()             # 3. automatic differentiation
        batch_processing()         # 4. vmap batching
        neural_network_example()   # 5. small neural net
        print("\n" + banner)
        print("所有示例执行完成!")
        print(banner)
        # Record a trace of one extra op so device activity can be inspected.
        print("\n验证GPU使用情况:")
        with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
            test_matrix = jnp.ones((1000, 1000))
            test_result = jnp.dot(test_matrix, test_matrix.T)
            print(f"跟踪测试完成,结果形状: {test_result.shape}")
    except Exception as e:
        # Demo-level catch-all: report the failure and the install hint.
        print(f"执行出错: {e}")
        print("\n注意:请确保已经正确安装JAX GPU版本")
        print("安装命令: pip install --upgrade 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment